Browse Source

Rebase to upstream

pull/10109/head
dolfies 3 years ago
parent
commit
c932b5d06b
  1. 44
      .github/workflows/crowdin_download.yml
  2. 44
      .github/workflows/crowdin_upload.yml
  3. 3
      discord/__init__.py
  4. 38
      discord/__main__.py
  5. 69
      discord/abc.py
  6. 59
      discord/activity.py
  7. 69
      discord/asset.py
  8. 36
      discord/audit_logs.py
  9. 60
      discord/channel.py
  10. 259
      discord/client.py
  11. 13
      discord/colour.py
  12. 10
      discord/components.py
  13. 10
      discord/context_managers.py
  14. 190
      discord/embeds.py
  15. 10
      discord/emoji.py
  16. 45
      discord/enums.py
  17. 24
      discord/ext/commands/_types.py
  18. 334
      discord/ext/commands/bot.py
  19. 134
      discord/ext/commands/cog.py
  20. 16
      discord/ext/commands/context.py
  21. 153
      discord/ext/commands/converter.py
  22. 4
      discord/ext/commands/cooldowns.py
  23. 195
      discord/ext/commands/core.py
  24. 33
      discord/ext/commands/errors.py
  25. 24
      discord/ext/commands/flags.py
  26. 585
      discord/ext/commands/help.py
  27. 37
      discord/ext/commands/view.py
  28. 182
      discord/ext/tasks/__init__.py
  29. 2
      discord/file.py
  30. 21
      discord/flags.py
  31. 84
      discord/gateway.py
  32. 139
      discord/guild.py
  33. 73
      discord/http.py
  34. 21
      discord/integrations.py
  35. 59
      discord/invite.py
  36. 58
      discord/member.py
  37. 10
      discord/mentions.py
  38. 2293
      discord/message.py
  39. 5
      discord/opus.py
  40. 15
      discord/partial_emoji.py
  41. 19
      discord/permissions.py
  42. 31
      discord/player.py
  43. 10
      discord/reaction.py
  44. 8
      discord/role.py
  45. 66
      discord/scheduled_event.py
  46. 22
      discord/stage_instance.py
  47. 74
      discord/state.py
  48. 4
      discord/template.py
  49. 70
      discord/threads.py
  50. 1
      discord/types/activity.py
  51. 1
      discord/types/channel.py
  52. 242
      discord/types/interactions.py
  53. 4
      discord/types/scheduled_event.py
  54. 14
      discord/types/widget.py
  55. 10
      discord/user.py
  56. 53
      discord/utils.py
  57. 25
      discord/voice_client.py
  58. 337
      discord/webhook/async_.py
  59. 148
      discord/webhook/sync.py
  60. 25
      discord/widget.py
  61. 16
      docs/_static/custom.js
  62. 18
      docs/_static/style.css
  63. 12
      docs/_templates/layout.html
  64. 1004
      docs/api.rst
  65. 2
      docs/conf.py
  66. 21
      docs/crowdin.yml
  67. 7
      docs/ext/commands/api.rst
  68. 5
      docs/ext/commands/cogs.rst
  69. 19
      docs/ext/commands/commands.rst
  70. 10
      docs/ext/commands/extensions.rst
  71. 4
      docs/faq.rst
  72. 1397
      docs/migrating.rst
  73. 1174
      docs/migrating_to_v1.rst
  74. 7
      docs/quickstart.rst
  75. 1
      examples/background_task.py
  76. 1
      examples/background_task_asyncio.py
  77. 8
      examples/basic_voice.py
  78. 69
      examples/modal.py
  79. 2
      pyproject.toml
  80. 2
      requirements.txt
  81. 2
      setup.py
  82. 103
      tests/test_ext_tasks.py

44
.github/workflows/crowdin_download.yml

@ -0,0 +1,44 @@
name: crowdin download
on:
schedule:
- cron: '0 18 * * 1'
workflow_dispatch:
jobs:
download:
runs-on: ubuntu-latest
environment: Crowdin
name: download
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
ref: master
- name: Install system dependencies
run: |
wget -qO - https://artifacts.crowdin.com/repo/GPG-KEY-crowdin | sudo apt-key add -
echo "deb https://artifacts.crowdin.com/repo/deb/ /" | sudo tee -a /etc/apt/sources.list.d/crowdin.list
sudo apt-get update -qq
sudo apt-get install -y crowdin3
- name: Download translations
shell: bash
run: |
cd docs
crowdin download --all
env:
CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }}
- name: Create pull request
id: cpr_crowdin
uses: peter-evans/create-pull-request@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Crowdin translations download
title: "[Crowdin] Updated translation files"
body: |
Created by the [Crowdin download workflow](.github/workflows/crowdin_download.yml).
branch: "auto/crowdin"
author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

44
.github/workflows/crowdin_upload.yml

@ -0,0 +1,44 @@
name: crowdin upload
on:
workflow_dispatch:
jobs:
upload:
runs-on: ubuntu-latest
environment: Crowdin
name: upload
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Set up CPython 3.x
uses: actions/setup-python@v2
with:
python-version: 3.x
- name: Install system dependencies
run: |
wget -qO - https://artifacts.crowdin.com/repo/GPG-KEY-crowdin | sudo apt-key add -
echo "deb https://artifacts.crowdin.com/repo/deb/ /" | sudo tee -a /etc/apt/sources.list.d/crowdin.list
sudo apt-get update -qq
sudo apt-get install -y crowdin3
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install -e .[docs,speed,voice]
- name: Build gettext
run: |
cd docs
make gettext
- name: Upload sources
shell: bash
run: |
cd docs
crowdin upload
env:
CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }}

3
discord/__init__.py

@ -77,3 +77,6 @@ class _VersionInfo(NamedTuple):
version_info: _VersionInfo = _VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=2) version_info: _VersionInfo = _VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=2)
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())
del logging, NamedTuple, Literal, _VersionInfo

38
discord/__main__.py

@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Optional, Tuple, Dict
import argparse import argparse
import sys import sys
from pathlib import Path from pathlib import Path
@ -32,7 +36,7 @@ import aiohttp
import platform import platform
def show_version(): def show_version() -> None:
entries = [] entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
@ -49,7 +53,7 @@ def show_version():
print('\n'.join(entries)) print('\n'.join(entries))
def core(parser, args): def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.version: if args.version:
show_version() show_version()
@ -63,9 +67,11 @@ import config
class Bot(commands.Bot): class Bot(commands.Bot):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), **kwargs) super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), **kwargs)
async def setup_hook(self):
for cog in config.cogs: for cog in config.cogs:
try: try:
self.load_extension(cog) await self.load_extension(cog)
except Exception as exc: except Exception as exc:
print(f'Could not load extension {{cog}} due to {{exc.__class__.__name__}}: {{exc}}') print(f'Could not load extension {{cog}} due to {{exc.__class__.__name__}}: {{exc}}')
@ -119,12 +125,16 @@ class {name}(commands.Cog{attrs}):
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
{extra} {extra}
def setup(bot): async def setup(bot):
bot.add_cog({name}(bot)) await bot.add_cog({name}(bot))
''' '''
_cog_extras = ''' _cog_extras = '''
def cog_unload(self): async def cog_load(self):
# loading logic goes here
pass
async def cog_unload(self):
# clean up logic goes here # clean up logic goes here
pass pass
@ -158,7 +168,7 @@ _cog_extras = '''
# certain file names and directory names are forbidden # certain file names and directory names are forbidden
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx # see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
# although some of this doesn't apply to Linux, we might as well be consistent # although some of this doesn't apply to Linux, we might as well be consistent
_base_table = { _base_table: Dict[str, Optional[str]] = {
'<': '-', '<': '-',
'>': '-', '>': '-',
':': '-', ':': '-',
@ -176,7 +186,7 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table) _translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False): def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path:
if isinstance(name, Path): if isinstance(name, Path):
return name return name
@ -214,7 +224,7 @@ def to_path(parser, name, *, replace_spaces=False):
return Path(name) return Path(name)
def newbot(parser, args): def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
new_directory = to_path(parser, args.directory) / to_path(parser, args.name) new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
# as a note exist_ok for Path is a 3.5+ only feature # as a note exist_ok for Path is a 3.5+ only feature
@ -255,7 +265,7 @@ def newbot(parser, args):
print('successfully made bot at', new_directory) print('successfully made bot at', new_directory)
def newcog(parser, args): def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
cog_dir = to_path(parser, args.directory) cog_dir = to_path(parser, args.directory)
try: try:
cog_dir.mkdir(exist_ok=True) cog_dir.mkdir(exist_ok=True)
@ -289,7 +299,7 @@ def newcog(parser, args):
print('successfully made cog at', directory) print('successfully made cog at', directory)
def add_newbot_args(subparser): def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser.set_defaults(func=newbot) parser.set_defaults(func=newbot)
@ -299,7 +309,7 @@ def add_newbot_args(subparser):
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
def add_newcog_args(subparser): def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser.set_defaults(func=newcog) parser.set_defaults(func=newcog)
@ -311,7 +321,7 @@ def add_newcog_args(subparser):
parser.add_argument('--full', help='add all special methods as well', action='store_true') parser.add_argument('--full', help='add all special methods as well', action='store_true')
def parse_args(): def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version') parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser.set_defaults(func=core) parser.set_defaults(func=core)
@ -322,7 +332,7 @@ def parse_args():
return parser, parser.parse_args() return parser, parser.parse_args()
def main(): def main() -> None:
parser, args = parse_args() parser, args = parse_args()
args.func(parser, args) args.func(parser, args)

69
discord/abc.py

@ -90,6 +90,9 @@ if TYPE_CHECKING:
GuildChannel as GuildChannelPayload, GuildChannel as GuildChannelPayload,
OverwriteType, OverwriteType,
) )
from .types.snowflake import (
SnowflakeList,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable] PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel] MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
@ -725,7 +728,14 @@ class GuildChannel:
) -> None: ) -> None:
... ...
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Any = _undefined,
reason: Optional[str] = None,
**permissions: bool,
) -> None:
r"""|coro| r"""|coro|
Sets the channel specific permission overwrites for a target in the Sets the channel specific permission overwrites for a target in the
@ -769,8 +779,8 @@ class GuildChannel:
await channel.set_permissions(member, overwrite=overwrite) await channel.set_permissions(member, overwrite=overwrite)
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
@ -934,7 +944,7 @@ class GuildChannel:
) -> None: ) -> None:
... ...
async def move(self, **kwargs) -> None: async def move(self, **kwargs: Any) -> None:
"""|coro| """|coro|
A rich interface to help move a channel relative to other channels. A rich interface to help move a channel relative to other channels.
@ -952,8 +962,8 @@ class GuildChannel:
.. versionadded:: 1.7 .. versionadded:: 1.7
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -1210,7 +1220,7 @@ class Messageable:
content: Optional[str] = ..., content: Optional[str] = ...,
*, *,
tts: bool = ..., tts: bool = ...,
files: List[File] = ..., files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ..., delete_after: float = ...,
nonce: Union[str, int] = ..., nonce: Union[str, int] = ...,
@ -1244,7 +1254,7 @@ class Messageable:
content: Optional[str] = ..., content: Optional[str] = ...,
*, *,
tts: bool = ..., tts: bool = ...,
files: List[File] = ..., files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ..., stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ..., delete_after: float = ...,
nonce: Union[str, int] = ..., nonce: Union[str, int] = ...,
@ -1257,19 +1267,19 @@ class Messageable:
async def send( async def send(
self, self,
content=None, content: Optional[str] = None,
*, *,
tts=False, tts: bool = False,
file=None, file: Optional[File] = None,
files=None, files: Optional[Sequence[File]] = None,
stickers=None, stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None,
delete_after=None, delete_after: Optional[float] = None,
nonce=MISSING, nonce: Optional[Union[str, int]] = MISSING,
allowed_mentions=None, allowed_mentions: Optional[AllowedMentions] = None,
reference=None, reference: Optional[Union[Message, MessageReference, PartialMessage]] = None,
mention_author=None, mention_author: Optional[bool] = None,
suppress_embeds=False, suppress_embeds: bool = False,
): ) -> Message:
"""|coro| """|coro|
Sends a message to the destination with the content given. Sends a message to the destination with the content given.
@ -1283,8 +1293,8 @@ class Messageable:
**Specifying both parameters will lead to an exception**. **Specifying both parameters will lead to an exception**.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -1362,17 +1372,17 @@ class Messageable:
nonce = str(utils.time_snowflake(datetime.utcnow())) nonce = str(utils.time_snowflake(datetime.utcnow()))
if stickers is not None: if stickers is not None:
stickers = [sticker.id for sticker in stickers] sticker_ids: SnowflakeList = [sticker.id for sticker in stickers]
else: else:
stickers = MISSING sticker_ids = MISSING
if reference is not None: if reference is not None:
try: try:
reference = reference.to_message_reference_dict() reference_dict = reference.to_message_reference_dict()
except AttributeError: except AttributeError:
raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None
else: else:
reference = MISSING reference_dict = MISSING
if suppress_embeds: if suppress_embeds:
from .message import MessageFlags # circular import from .message import MessageFlags # circular import
@ -1388,10 +1398,10 @@ class Messageable:
files=files if files is not None else MISSING, files=files if files is not None else MISSING,
nonce=nonce, nonce=nonce,
allowed_mentions=allowed_mentions, allowed_mentions=allowed_mentions,
message_reference=reference, message_reference=reference_dict,
previous_allowed_mentions=previous_allowed_mention, previous_allowed_mentions=previous_allowed_mention,
mention_author=mention_author, mention_author=mention_author,
stickers=stickers, stickers=sticker_ids,
flags=flags, flags=flags,
) as params: ) as params:
data = await state.http.send_message(channel.id, params=params) data = await state.http.send_message(channel.id, params=params)
@ -1823,7 +1833,8 @@ class Connectable(Protocol):
if cls is MISSING: if cls is MISSING:
cls = VoiceClient cls = VoiceClient
voice = cls(state.client, channel) # The type checker doesn't understand that VoiceClient *is* T here.
voice: T = cls(state.client, channel) # type: ignore
if not isinstance(voice, VoiceProtocol): if not isinstance(voice, VoiceProtocol):
raise TypeError('Type must meet VoiceProtocol abstract base class') raise TypeError('Type must meet VoiceProtocol abstract base class')

59
discord/activity.py

@ -99,6 +99,8 @@ if TYPE_CHECKING:
ActivityButton, ActivityButton,
) )
from .state import ConnectionState
class BaseActivity: class BaseActivity:
"""The base activity that all user-settable activities inherit from. """The base activity that all user-settable activities inherit from.
@ -121,7 +123,7 @@ class BaseActivity:
__slots__ = ('_created_at',) __slots__ = ('_created_at',)
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
self._created_at: Optional[float] = kwargs.pop('created_at', None) self._created_at: Optional[float] = kwargs.pop('created_at', None)
@property @property
@ -216,7 +218,7 @@ class Activity(BaseActivity):
'buttons', 'buttons',
) )
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None) self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None) self.details: Optional[str] = kwargs.pop('details', None)
@ -377,7 +379,7 @@ class Game(BaseActivity):
__slots__ = ('name', '_end', '_start') __slots__ = ('name', '_end', '_start')
def __init__(self, name: str, **extra): def __init__(self, name: str, **extra: Any) -> None:
super().__init__(**extra) super().__init__(**extra)
self.name: str = name self.name: str = name
@ -434,10 +436,10 @@ class Game(BaseActivity):
} }
# fmt: on # fmt: on
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, Game) and other.name == self.name return isinstance(other, Game) and other.name == self.name
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -491,7 +493,7 @@ class Streaming(BaseActivity):
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
def __init__(self, *, name: Optional[str], url: str, **extra: Any): def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None:
super().__init__(**extra) super().__init__(**extra)
self.platform: Optional[str] = name self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name) self.name: Optional[str] = extra.pop('details', name)
@ -515,7 +517,7 @@ class Streaming(BaseActivity):
return f'<Streaming name={self.name!r}>' return f'<Streaming name={self.name!r}>'
@property @property
def twitch_name(self): def twitch_name(self) -> Optional[str]:
"""Optional[:class:`str`]: If provided, the twitch name of the user streaming. """Optional[:class:`str`]: If provided, the twitch name of the user streaming.
This corresponds to the ``large_image`` key of the :attr:`Streaming.assets` This corresponds to the ``large_image`` key of the :attr:`Streaming.assets`
@ -542,10 +544,10 @@ class Streaming(BaseActivity):
ret['details'] = self.details ret['details'] = self.details
return ret return ret
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -577,14 +579,14 @@ class Spotify:
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
def __init__(self, **data): def __init__(self, **data: Any) -> None:
self._state: str = data.pop('state', '') self._state: str = data.pop('state', '')
self._details: str = data.pop('details', '') self._details: str = data.pop('details', '')
self._timestamps: Dict[str, int] = data.pop('timestamps', {}) self._timestamps: ActivityTimestamps = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {}) self._assets: ActivityAssets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {}) self._party: ActivityParty = data.pop('party', {})
self._sync_id: str = data.pop('sync_id') self._sync_id: str = data.pop('sync_id', '')
self._session_id: str = data.pop('session_id') self._session_id: Optional[str] = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None) self._created_at: Optional[float] = data.pop('created_at', None)
@property @property
@ -636,7 +638,7 @@ class Spotify:
""":class:`str`: The activity's name. This will always return "Spotify".""" """:class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify' return 'Spotify'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return ( return (
isinstance(other, Spotify) isinstance(other, Spotify)
and other._session_id == self._session_id and other._session_id == self._session_id
@ -644,7 +646,7 @@ class Spotify:
and other.start == self.start and other.start == self.start
) )
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -705,12 +707,14 @@ class Spotify:
@property @property
def start(self) -> datetime.datetime: def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC.""" """:class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # the start key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property @property
def end(self) -> datetime.datetime: def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC.""" """:class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # the end key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property @property
def duration(self) -> datetime.timedelta: def duration(self) -> datetime.timedelta:
@ -820,10 +824,10 @@ class CustomActivity(BaseActivity):
o['expires_at'] = expiry.isoformat() o['expires_at'] = expiry.isoformat()
return o return o
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -845,16 +849,16 @@ ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload @overload
def create_activity(data: ActivityPayload) -> ActivityTypes: def create_activity(data: ActivityPayload, state: ConnectionState) -> ActivityTypes:
... ...
@overload @overload
def create_activity(data: None) -> None: def create_activity(data: None, state: ConnectionState) -> None:
... ...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: def create_activity(data: Optional[ActivityPayload], state: ConnectionState) -> Optional[ActivityTypes]:
if not data: if not data:
return None return None
@ -867,10 +871,10 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
try: try:
name = data.pop('name') name = data.pop('name')
except KeyError: except KeyError:
return Activity(**data) ret = Activity(**data)
else: else:
# We removed the name key from data already # We removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore ret = CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming: elif game_type is ActivityType.streaming:
if 'url' in data: if 'url' in data:
# The url won't be None here # The url won't be None here
@ -878,7 +882,12 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
return Activity(**data) return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
return Spotify(**data) return Spotify(**data)
return Activity(**data) else:
ret = Activity(**data)
if isinstance(ret.emoji, PartialEmoji):
ret.emoji._state = state
return ret
def create_settings_activity(*, data, state): def create_settings_activity(*, data, state):

69
discord/asset.py

@ -39,6 +39,13 @@ __all__ = (
# fmt: on # fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .state import ConnectionState
from .webhook.async_ import _WebhookState
_State = Union[ConnectionState, _WebhookState]
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
@ -77,7 +84,7 @@ class AssetMixin:
return await self._state.http.get_from_cdn(self.url) return await self._state.http.get_from_cdn(self.url)
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int: async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int:
"""|coro| """|coro|
Saves this asset into a file-like object. Saves this asset into a file-like object.
@ -153,14 +160,14 @@ class Asset(AssetMixin):
BASE = 'https://cdn.discordapp.com' BASE = 'https://cdn.discordapp.com'
def __init__(self, state, *, url: str, key: str, animated: bool = False): def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None:
self._state = state self._state: _State = state
self._url = url self._url: str = url
self._animated = animated self._animated: bool = animated
self._key = key self._key: str = key
@classmethod @classmethod
def _from_default_avatar(cls, state, index: int) -> Asset: def _from_default_avatar(cls, state: _State, index: int) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/embed/avatars/{index}.png', url=f'{cls.BASE}/embed/avatars/{index}.png',
@ -169,7 +176,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_') animated = avatar.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -180,7 +187,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_') animated = avatar.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -191,7 +198,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
@ -200,7 +207,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
@ -209,7 +216,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_scheduled_event_cover_image(cls, state, scheduled_event_id: int, cover_image_hash: str) -> Asset: def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024', url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024',
@ -218,7 +225,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self:
animated = image.startswith('a_') animated = image.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -229,7 +236,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self:
animated = icon_hash.startswith('a_') animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -240,7 +247,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset: def _from_sticker_banner(cls, state: _State, banner: int) -> Self:
return cls( return cls(
state, state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
@ -249,7 +256,7 @@ class Asset(AssetMixin):
) )
@classmethod @classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self:
animated = banner_hash.startswith('a_') animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png' format = 'gif' if animated else 'png'
return cls( return cls(
@ -274,14 +281,14 @@ class Asset(AssetMixin):
def __len__(self) -> int: def __len__(self) -> int:
return len(self._url) return len(self._url)
def __repr__(self): def __repr__(self) -> str:
shorten = self._url.replace(self.BASE, '') shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>' return f'<Asset url={shorten!r}>'
def __eq__(self, other): def __eq__(self, other: object) -> bool:
return isinstance(other, Asset) and self._url == other._url return isinstance(other, Asset) and self._url == other._url
def __hash__(self): def __hash__(self) -> int:
return hash(self._url) return hash(self._url)
@property @property
@ -304,12 +311,12 @@ class Asset(AssetMixin):
size: int = MISSING, size: int = MISSING,
format: ValidAssetFormatTypes = MISSING, format: ValidAssetFormatTypes = MISSING,
static_format: ValidStaticFormatTypes = MISSING, static_format: ValidStaticFormatTypes = MISSING,
) -> Asset: ) -> Self:
"""Returns a new asset with the passed components replaced. """Returns a new asset with the passed components replaced.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -359,12 +366,12 @@ class Asset(AssetMixin):
url = str(url) url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Asset: def with_size(self, size: int, /) -> Self:
"""Returns a new asset with the specified size. """Returns a new asset with the specified size.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -387,12 +394,12 @@ class Asset(AssetMixin):
url = str(yarl.URL(self._url).with_query(size=size)) url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset: def with_format(self, format: ValidAssetFormatTypes, /) -> Self:
"""Returns a new asset with the specified format. """Returns a new asset with the specified format.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -422,15 +429,15 @@ class Asset(AssetMixin):
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self:
"""Returns a new asset with the specified static format. """Returns a new asset with the specified static format.
This only changes the format if the underlying asset is This only changes the format if the underlying asset is
not animated. Otherwise, the asset is not changed. not animated. Otherwise, the asset is not changed.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------

36
discord/audit_logs.py

@ -49,12 +49,13 @@ if TYPE_CHECKING:
from .guild import Guild from .guild import Guild
from .member import Member from .member import Member
from .role import Role from .role import Role
from .scheduled_event import ScheduledEvent
from .state import ConnectionState
from .types.audit_log import ( from .types.audit_log import (
AuditLogChange as AuditLogChangePayload, AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload, AuditLogEntry as AuditLogEntryPayload,
) )
from .types.channel import ( from .types.channel import (
PartialChannel as PartialChannelPayload,
PermissionOverwrite as PermissionOverwritePayload, PermissionOverwrite as PermissionOverwritePayload,
) )
from .types.invite import Invite as InvitePayload from .types.invite import Invite as InvitePayload
@ -241,8 +242,8 @@ class AuditLogChanges:
# fmt: on # fmt: on
def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]): def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]):
self.before = AuditLogDiff() self.before: AuditLogDiff = AuditLogDiff()
self.after = AuditLogDiff() self.after: AuditLogDiff = AuditLogDiff()
for elem in data: for elem in data:
attr = elem['key'] attr = elem['key']
@ -389,16 +390,17 @@ class AuditLogEntry(Hashable):
""" """
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
self._state = guild._state self._state: ConnectionState = guild._state
self.guild = guild self.guild: Guild = guild
self._users = users self._users: Dict[int, User] = users
self._from_data(data) self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None: def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id']) self.id: int = int(data['id'])
self.reason = data.get('reason') # This key is technically not usually present
self.reason: Optional[str] = data.get('reason')
extra = data.get('options') extra = data.get('options')
# fmt: off # fmt: off
@ -462,10 +464,13 @@ class AuditLogEntry(Hashable):
self._changes = data.get('changes', []) self._changes = data.get('changes', [])
user_id = utils._get_as_snowflake(data, 'user_id') user_id = utils._get_as_snowflake(data, 'user_id')
self.user = user_id and self._get_member(user_id) self.user: Optional[Union[User, Member]] = self._get_member(user_id)
self._target_id = utils._get_as_snowflake(data, 'target_id') self._target_id = utils._get_as_snowflake(data, 'target_id')
def _get_member(self, user_id: int) -> Union[Member, User, None]: def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]:
if user_id is None:
return None
return self.guild.get_member(user_id) or self._users.get(user_id) return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str: def __repr__(self) -> str:
@ -478,12 +483,14 @@ class AuditLogEntry(Hashable):
@utils.cached_property @utils.cached_property
def target(self) -> TargetType: def target(self) -> TargetType:
if self._target_id is None or self.action.target_type is None: if self.action.target_type is None:
return None return None
try: try:
converter = getattr(self, '_convert_target_' + self.action.target_type) converter = getattr(self, '_convert_target_' + self.action.target_type)
except AttributeError: except AttributeError:
if self._target_id is None:
return None
return Object(id=self._target_id) return Object(id=self._target_id)
else: else:
return converter(self._target_id) return converter(self._target_id)
@ -522,7 +529,7 @@ class AuditLogEntry(Hashable):
def _convert_target_role(self, target_id: int) -> Union[Role, Object]: def _convert_target_role(self, target_id: int) -> Union[Role, Object]:
return self.guild.get_role(target_id) or Object(id=target_id) return self.guild.get_role(target_id) or Object(id=target_id)
def _convert_target_invite(self, target_id: int) -> Invite: def _convert_target_invite(self, target_id: None) -> Invite:
# Invites have target_id set to null # Invites have target_id set to null
# So figure out which change has the full invite data # So figure out which change has the full invite data
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
@ -557,3 +564,6 @@ class AuditLogEntry(Hashable):
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]: def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id) return self.guild.get_thread(target_id) or Object(id=target_id)
def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]:
return self.guild.get_scheduled_event(target_id) or Object(id=target_id)

60
discord/channel.py

@ -46,7 +46,6 @@ from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, PrivacyLevel, try_enum, VideoQualityMode from .enums import ChannelType, PrivacyLevel, try_enum, VideoQualityMode
from .calls import PrivateCall, GroupCall from .calls import PrivateCall, GroupCall
from .mixins import Hashable from .mixins import Hashable
from .object import Object
from . import utils from . import utils
from .utils import MISSING from .utils import MISSING
from .asset import Asset from .asset import Asset
@ -194,7 +193,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data) self._fill_overwrites(data)
async def _get_channel(self): async def _get_channel(self) -> Self:
return self return self
@property @property
@ -279,7 +278,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[TextChannel]: async def edit(self) -> Optional[TextChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -297,8 +296,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Edits are no longer in-place, the newly edited channel is returned instead. Edits are no longer in-place, the newly edited channel is returned instead.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
Parameters Parameters
---------- ----------
@ -574,8 +573,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
.. versionadded:: 1.3 .. versionadded:: 1.3
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -696,7 +695,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
reason: Optional[:class:`str`] reason: Optional[:class:`str`]
The reason for creating a new thread. Shows up on the audit log. The reason for creating a new thread. Shows up on the audit log.
invitable: :class:`bool` invitable: :class:`bool`
Whether non-modertators can add users to the thread. Only applicable to private threads. Whether non-moderators can add users to the thread. Only applicable to private threads.
Defaults to ``True``. Defaults to ``True``.
slowmode_delay: Optional[:class:`int`] slowmode_delay: Optional[:class:`int`]
Specifies the slowmode rate limit for user in this channel, in seconds. Specifies the slowmode rate limit for user in this channel, in seconds.
@ -863,7 +862,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
return self.guild.id, self.id return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild self.guild: Guild = guild
self.name: str = data['name'] self.name: str = data['name']
self.rtc_region: Optional[str] = data.get('rtc_region') self.rtc_region: Optional[str] = data.get('rtc_region')
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
@ -1031,7 +1030,7 @@ class VoiceChannel(VocalGuildChannel):
async def edit(self) -> Optional[VoiceChannel]: async def edit(self) -> Optional[VoiceChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1049,8 +1048,8 @@ class VoiceChannel(VocalGuildChannel):
The ``region`` parameter now accepts :class:`str` instead of an enum. The ``region`` parameter now accepts :class:`str` instead of an enum.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
---------- ----------
@ -1175,7 +1174,7 @@ class StageChannel(VocalGuildChannel):
def _update(self, guild: Guild, data: StageChannelPayload) -> None: def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data) super()._update(guild, data)
self.topic = data.get('topic') self.topic: Optional[str] = data.get('topic')
@property @property
def requesting_to_speak(self) -> List[Member]: def requesting_to_speak(self) -> List[Member]:
@ -1316,7 +1315,7 @@ class StageChannel(VocalGuildChannel):
async def edit(self) -> Optional[StageChannel]: async def edit(self) -> Optional[StageChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1334,8 +1333,8 @@ class StageChannel(VocalGuildChannel):
The ``region`` parameter now accepts :class:`str` instead of an enum. The ``region`` parameter now accepts :class:`str` instead of an enum.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
---------- ----------
@ -1477,7 +1476,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[CategoryChannel]: async def edit(self) -> Optional[CategoryChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1492,8 +1491,8 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
Edits are no longer in-place, the newly edited channel is returned instead. Edits are no longer in-place, the newly edited channel is returned instead.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
Parameters Parameters
---------- ----------
@ -1533,7 +1532,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.move) @utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs): async def move(self, **kwargs: Any) -> None:
kwargs.pop('category', None) kwargs.pop('category', None)
await super().move(**kwargs) await super().move(**kwargs)
@ -1717,9 +1716,9 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
position: int = ..., position: int = ...,
nsfw: bool = ..., nsfw: bool = ...,
sync_permissions: bool = ..., sync_permissions: bool = ...,
category: Optional[CategoryChannel], category: Optional[CategoryChannel] = ...,
reason: Optional[str], reason: Optional[str] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite], overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
) -> Optional[StoreChannel]: ) -> Optional[StoreChannel]:
... ...
@ -1727,7 +1726,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[StoreChannel]: async def edit(self) -> Optional[StoreChannel]:
... ...
async def edit(self, *, reason=None, **options): async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]:
"""|coro| """|coro|
Edits the channel. Edits the channel.
@ -1739,8 +1738,8 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
Edits are no longer in-place, the newly edited channel is returned instead. Edits are no longer in-place, the newly edited channel is returned instead.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
Parameters Parameters
---------- ----------
@ -1844,7 +1843,7 @@ class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
def _add_call(self, **kwargs) -> PrivateCall: def _add_call(self, **kwargs) -> PrivateCall:
return PrivateCall(**kwargs) return PrivateCall(**kwargs)
async def _get_channel(self): async def _get_channel(self) -> Self:
await self._state.access_private_channel(self.id) await self._state.access_private_channel(self.id)
return self return self
@ -2066,7 +2065,7 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
def _get_voice_state_pair(self) -> Tuple[int, int]: def _get_voice_state_pair(self) -> Tuple[int, int]:
return self.me.id, self.id return self.me.id, self.id
async def _get_channel(self): async def _get_channel(self) -> Self:
await self._state.access_private_channel(self.id) await self._state.access_private_channel(self.id)
return self return self
@ -2331,7 +2330,7 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
class PartialMessageable(discord.abc.Messageable, Hashable): class PartialMessageable(discord.abc.Messageable, Hashable):
"""Represents a partial messageable to aid with working messageable channels when """Represents a partial messageable to aid with working messageable channels when
only a channel ID are present. only a channel ID is present.
The only way to construct this class is through :meth:`Client.get_partial_messageable`. The only way to construct this class is through :meth:`Client.get_partial_messageable`.
@ -2367,6 +2366,9 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
self.type: Optional[ChannelType] = type self.type: Optional[ChannelType] = type
self.last_message_id: Optional[int] = None self.last_message_id: Optional[int] = None
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} type={self.type!r}>'
async def _get_channel(self) -> PartialMessageable: async def _get_channel(self) -> PartialMessageable:
return self return self

259
discord/client.py

@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
import logging import logging
import signal
import sys import sys
import traceback import traceback
from typing import ( from typing import (
@ -42,6 +41,7 @@ from typing import (
Sequence, Sequence,
TYPE_CHECKING, TYPE_CHECKING,
Tuple, Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
) )
@ -80,6 +80,8 @@ from .team import Team
from .member import _ClientStatus from .member import _ClientStatus
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .guild import GuildChannel from .guild import GuildChannel
from .abc import PrivateChannel, GuildChannel, Snowflake, SnowflakeTime from .abc import PrivateChannel, GuildChannel, Snowflake, SnowflakeTime
@ -97,43 +99,22 @@ __all__ = (
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: class _LoopSentinel:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} __slots__ = ()
if not tasks:
return
_log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.')
for task in tasks: def __getattr__(self, attr: str) -> None:
if task.cancelled(): msg = (
continue 'loop attribute cannot be accessed in non-async contexts. '
if task.exception() is not None: 'Consider using either an asynchronous main function and passing it to asyncio.run or '
loop.call_exception_handler( 'using asynchronous initialisation hooks such as Client.setup_hook'
{
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task,
}
) )
raise AttributeError(msg)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: _loop: Any = _LoopSentinel()
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
_log.info('Closing the event loop.')
loop.close()
class Client: class Client:
@ -150,12 +131,6 @@ class Client:
.. versionchanged:: 1.3 .. versionchanged:: 1.3
Allow disabling the message cache and change the default size to ``1000``. Allow disabling the message cache and change the default size to ``1000``.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations.
Defaults to ``None``, in which case the default event loop is used via
:func:`asyncio.get_event_loop()`.
connector: Optional[:class:`aiohttp.BaseConnector`]
The connector to use for connection pooling.
proxy: Optional[:class:`str`] proxy: Optional[:class:`str`]
Proxy URL. Proxy URL.
proxy_auth: Optional[:class:`aiohttp.BasicAuth`] proxy_auth: Optional[:class:`aiohttp.BasicAuth`]
@ -172,10 +147,9 @@ class Client:
.. versionadded:: 1.5 .. versionadded:: 1.5
request_guilds: :class:`bool` request_guilds: :class:`bool`
Whether to request guilds at startup (behaves similarly to the old Whether to request guilds at startup. Defaults to True.
guild_subscriptions option). Defaults to True.
.. versionadded:: 1.10 .. versionadded:: 2.0
status: Optional[:class:`.Status`] status: Optional[:class:`.Status`]
A status to start your presence with upon logging on to Discord. A status to start your presence with upon logging on to Discord.
activity: Optional[:class:`.BaseActivity`] activity: Optional[:class:`.BaseActivity`]
@ -209,36 +183,44 @@ class Client:
Whether to keep presences up-to-date across clients. Whether to keep presences up-to-date across clients.
The default behavior is ``True`` (what the client does). The default behavior is ``True`` (what the client does).
.. versionadded:: 2.0
http_trace: :class:`aiohttp.TraceConfig`
The trace configuration to use for tracking HTTP requests the library does using ``aiohttp``.
This allows you to check requests the library is using. For more information, check the
`aiohttp documentation <https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing>`_.
.. versionadded:: 2.0
Attributes Attributes
----------- -----------
ws ws
The websocket gateway the client is currently connected to. Could be ``None``. The websocket gateway the client is currently connected to. Could be ``None``.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
""" """
def __init__( def __init__(self, **options: Any) -> None:
self, self.loop: asyncio.AbstractEventLoop = _loop
*, # self.ws is set in the connect method
loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any,
):
# Set in the connect method
self.ws: DiscordWebSocket = None # type: ignore self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None) proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop('assume_unsync_clock', True)
http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None)
self.http: HTTPClient = HTTPClient( self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop self.loop,
proxy=proxy,
proxy_auth=proxy_auth,
unsync_clock=unsync_clock,
http_trace=http_trace,
) )
self._handlers: Dict[str, Callable] = {'ready': self._handle_ready, 'connect': self._handle_connect} self._handlers: Dict[str, Callable[..., None]] = {
'ready': self._handle_ready,
'connect': self._handle_connect,
}
self._hooks: Dict[str, Callable] = { self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
'before_identify': self._call_before_identify_hook, 'before_identify': self._call_before_identify_hook,
} }
@ -246,7 +228,7 @@ class Client:
self._sync_presences: bool = options.pop('sync_presence', True) self._sync_presences: bool = options.pop('sync_presence', True)
self._connection: ConnectionState = self._get_state(**options) self._connection: ConnectionState = self._get_state(**options)
self._closed: bool = False self._closed: bool = False
self._ready: asyncio.Event = asyncio.Event() self._ready: asyncio.Event = MISSING
self._client_status: _ClientStatus = _ClientStatus() self._client_status: _ClientStatus = _ClientStatus()
self._client_activities: Dict[Optional[str], Tuple[ActivityTypes, ...]] = { self._client_activities: Dict[Optional[str], Tuple[ActivityTypes, ...]] = {
@ -259,6 +241,19 @@ class Client:
VoiceClient.warn_nacl = False VoiceClient.warn_nacl = False
_log.warning('PyNaCl is not installed, voice will NOT be supported.') _log.warning('PyNaCl is not installed, voice will NOT be supported.')
async def __aenter__(self) -> Self:
await self._async_setup_hook()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
await self.close()
# Internals # Internals
def _get_state(self, **options: Any) -> ConnectionState: def _get_state(self, **options: Any) -> ConnectionState:
@ -350,7 +345,7 @@ class Client:
def is_ready(self) -> bool: def is_ready(self) -> bool:
""":class:`bool`: Specifies if the client's internal cache is ready for use.""" """:class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set() return self._ready is not MISSING and self._ready.is_set()
async def _run_event( async def _run_event(
self, self,
@ -377,9 +372,10 @@ class Client:
**kwargs: Any, **kwargs: Any,
) -> asyncio.Task: ) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs) wrapped = self._run_event(coro, event_name, *args, **kwargs)
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') # Schedules the task
return self.loop.create_task(wrapped, name=f'discord.py: {event_name}')
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s.', event) _log.debug('Dispatching event %s.', event)
method = 'on_' + event method = 'on_' + event
@ -419,7 +415,7 @@ class Client:
else: else:
self._schedule_event(coro, method, *args, **kwargs) self._schedule_event(coro, method, *args, **kwargs)
async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None:
"""|coro| """|coro|
The default error handler provided by the client. The default error handler provided by the client.
@ -427,6 +423,10 @@ class Client:
By default this prints to :data:`sys.stderr` however it could be By default this prints to :data:`sys.stderr` however it could be
overridden to have a different implementation. overridden to have a different implementation.
Check :func:`~discord.on_error` for more details. Check :func:`~discord.on_error` for more details.
.. versionchanged:: 2.0
``event_method`` parameter is now positional-only.
""" """
print(f'Ignoring exception in {event_method}', file=sys.stderr) print(f'Ignoring exception in {event_method}', file=sys.stderr)
traceback.print_exc() traceback.print_exc()
@ -471,12 +471,45 @@ class Client:
""" """
pass pass
async def _async_setup_hook(self) -> None:
# Called whenever the client needs to initialise asyncio objects with a running loop
loop = asyncio.get_running_loop()
self.loop = loop
self.http.loop = loop
self._connection.loop = loop
await self._connection.async_setup()
self._ready = asyncio.Event()
async def setup_hook(self) -> None:
"""|coro|
A coroutine to be called to setup the bot, by default this is blank.
To perform asynchronous setup after the bot is logged in but before
it has connected to the Websocket, overwrite this coroutine.
This is only called once, in :meth:`login`, and will be called before
any events are dispatched, making it a better solution than doing such
setup in the :func:`~discord.on_ready` event.
.. warning::
Since this is called *before* the websocket connection is made therefore
anything that waits for the websocket will deadlock, this includes things
like :meth:`wait_for` and :meth:`wait_until_ready`.
.. versionadded:: 2.0
"""
pass
# Login state management # Login state management
async def login(self, token: str) -> None: async def login(self, token: str) -> None:
"""|coro| """|coro|
Logs in the client with the specified credentials. Logs in the client with the specified credentials and
calls the :meth:`setup_hook`.
.. warning:: .. warning::
@ -502,10 +535,13 @@ class Client:
_log.info('Logging in using static token.') _log.info('Logging in using static token.')
await self._async_setup_hook()
state = self._connection state = self._connection
data = await state.http.static_login(token.strip()) data = await state.http.static_login(token.strip())
state.analytics_token = data.get('analytics_token', '') state.analytics_token = data.get('analytics_token', '')
state.user = ClientUser(state=state, data=data) state.user = ClientUser(state=state, data=data)
await self.setup_hook()
async def connect(self, *, reconnect: bool = True) -> None: async def connect(self, *, reconnect: bool = True) -> None:
"""|coro| """|coro|
@ -611,8 +647,12 @@ class Client:
await self.ws.close(code=1000) await self.ws.close(code=1000)
await self.http.close() await self.http.close()
if self._ready is not MISSING:
self._ready.clear() self._ready.clear()
self.loop = MISSING
def clear(self) -> None: def clear(self) -> None:
"""Clears the internal state of the bot. """Clears the internal state of the bot.
@ -644,12 +684,9 @@ class Client:
Roughly Equivalent to: :: Roughly Equivalent to: ::
try: try:
loop.run_until_complete(start(*args, **kwargs)) asyncio.run(self.start(*args, **kwargs))
except KeyboardInterrupt: except KeyboardInterrupt:
loop.run_until_complete(close()) return
# cancel all tasks lingering
finally:
loop.close()
.. warning:: .. warning::
@ -657,41 +694,18 @@ class Client:
is blocking. That means that registration of events or anything being is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns. called after this function call will not execute until it returns.
""" """
loop = self.loop
try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner(): async def runner():
try: async with self:
await self.start(*args, **kwargs) await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
_log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
try: try:
return future.result() asyncio.run(runner())
except KeyboardInterrupt: except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway # nothing to do here
return None # `asyncio.run` handles the loop cleanup
# and `self.start` closes all sockets and the HTTPClient instance.
return
# Properties # Properties
@ -712,7 +726,8 @@ class Client:
The client may be setting multiple activities, these can be accessed under :attr:`initial_activities`. The client may be setting multiple activities, these can be accessed under :attr:`initial_activities`.
""" """
return create_activity(self._connection._activities[0]) if self._connection._activities else None state = self._connection
return create_activity(state._activities[0], state) if state._activities else None
@initial_activity.setter @initial_activity.setter
def initial_activity(self, value: Optional[ActivityTypes]) -> None: def initial_activity(self, value: Optional[ActivityTypes]) -> None:
@ -727,7 +742,8 @@ class Client:
@property @property
def initial_activities(self) -> List[ActivityTypes]: def initial_activities(self) -> List[ActivityTypes]:
"""List[:class:`.BaseActivity`]: The activities set upon logging in.""" """List[:class:`.BaseActivity`]: The activities set upon logging in."""
return [create_activity(activity) for activity in self._connection._activities] state = self._connection
return [create_activity(activity, state) for activity in state._activities]
@initial_activities.setter @initial_activities.setter
def initial_activities(self, values: List[ActivityTypes]) -> None: def initial_activities(self, values: List[ActivityTypes]) -> None:
@ -750,7 +766,7 @@ class Client:
return return
@initial_status.setter @initial_status.setter
def initial_status(self, value): def initial_status(self, value: Status):
if value is Status.offline: if value is Status.offline:
self._connection._status = 'invisible' self._connection._status = 'invisible'
elif isinstance(value, Status): elif isinstance(value, Status):
@ -837,9 +853,10 @@ class Client:
the user is listening to a song on Spotify with a title longer the user is listening to a song on Spotify with a title longer
than 128 characters. See :issue:`1738` for more information. than 128 characters. See :issue:`1738` for more information.
""" """
activities = tuple(map(create_activity, self._client_activities[None])) state = self._connection
activities = tuple(create_activity(d, state) for d in self._client_activities[None])
if activities is None and not self.is_closed(): if activities is None and not self.is_closed():
activities = getattr(self._connection.settings, 'custom_activity', []) activities = getattr(state.settings, 'custom_activity', [])
activities = [activities] if activities else activities activities = [activities] if activities else activities
return activities return activities
@ -870,7 +887,8 @@ class Client:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return tuple(map(create_activity, self._client_activities.get('mobile', []))) state = self._connection
return tuple(create_activity(d, state) for d in self._client_activities.get('mobile', []))
@property @property
def desktop_activities(self) -> Tuple[ActivityTypes]: def desktop_activities(self) -> Tuple[ActivityTypes]:
@ -879,7 +897,8 @@ class Client:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return tuple(map(create_activity, self._client_activities.get('desktop', []))) state = self._connection
return tuple(create_activity(d, state) for d in self._client_activities.get('desktop', []))
@property @property
def web_activities(self) -> Tuple[ActivityTypes]: def web_activities(self) -> Tuple[ActivityTypes]:
@ -888,7 +907,8 @@ class Client:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return tuple(map(create_activity, self._client_activities.get('web', []))) state = self._connection
return tuple(create_activity(d, state) for d in self._client_activities.get('web', []))
@property @property
def client_activities(self) -> Tuple[ActivityTypes]: def client_activities(self) -> Tuple[ActivityTypes]:
@ -897,9 +917,10 @@ class Client:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
activities = tuple(map(create_activity, self._client_activities.get('this', []))) state = self._connection
activities = tuple(create_activity(d, state) for d in self._client_activities.get('this', []))
if activities is None and not self.is_closed(): if activities is None and not self.is_closed():
activities = getattr(self._connection.settings, 'custom_activity', []) activities = getattr(state.settings, 'custom_activity', [])
activities = [activities] if activities else activities activities = [activities] if activities else activities
return activities return activities
@ -979,7 +1000,7 @@ class Client:
Returns Returns
-------- --------
Optional[:class:`.StageInstance`] Optional[:class:`.StageInstance`]
The returns stage instance of ``None`` if not found. The stage instance or ``None`` if not found.
""" """
from .channel import StageChannel from .channel import StageChannel
@ -1109,12 +1130,18 @@ class Client:
"""|coro| """|coro|
Waits until the client's internal cache is all ready. Waits until the client's internal cache is all ready.
.. warning::
Calling this inside :meth:`setup_hook` can lead to a deadlock.
""" """
if self._ready is not MISSING:
await self._ready.wait() await self._ready.wait()
def wait_for( def wait_for(
self, self,
event: str, event: str,
/,
*, *,
check: Optional[Callable[..., bool]] = None, check: Optional[Callable[..., bool]] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
@ -1174,6 +1201,10 @@ class Client:
else: else:
await channel.send('\N{THUMBS UP SIGN}') await channel.send('\N{THUMBS UP SIGN}')
.. versionchanged:: 2.0
``event`` parameter is now positional-only.
Parameters Parameters
------------ ------------
@ -1220,7 +1251,7 @@ class Client:
# Event registration # Event registration
def event(self, coro: Coro) -> Coro: def event(self, coro: Coro, /) -> Coro:
"""A decorator that registers an event to listen to. """A decorator that registers an event to listen to.
You can find more info about the events on the :ref:`documentation below <discord-api-events>`. You can find more info about the events on the :ref:`documentation below <discord-api-events>`.
@ -1236,6 +1267,10 @@ class Client:
async def on_ready(): async def on_ready():
print('Ready!') print('Ready!')
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Raises Raises
-------- --------
TypeError TypeError
@ -1257,7 +1292,7 @@ class Client:
status: Optional[Status] = None, status: Optional[Status] = None,
afk: bool = False, afk: bool = False,
edit_settings: bool = True, edit_settings: bool = True,
): ) -> None:
"""|coro| """|coro|
Changes the client's presence. Changes the client's presence.
@ -1267,8 +1302,8 @@ class Client:
Added option to update settings. Added option to update settings.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Example Example
--------- ---------
@ -1439,7 +1474,7 @@ class Client:
""" """
code = utils.resolve_template(code) code = utils.resolve_template(code)
data = await self.http.get_template(code) data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore return Template(data=data, state=self._connection)
async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild: async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild:
"""|coro| """|coro|
@ -1498,8 +1533,8 @@ class Client:
``name`` and ``icon`` parameters are now keyword-only. The `region`` parameter has been removed. ``name`` and ``icon`` parameters are now keyword-only. The `region`` parameter has been removed.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
---------- ----------

13
discord/colour.py

@ -32,7 +32,6 @@ from typing import (
Callable, Callable,
Optional, Optional,
Tuple, Tuple,
Type,
Union, Union,
) )
@ -90,10 +89,10 @@ class Colour:
def _get_byte(self, byte: int) -> int: def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xFF return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, Colour) and self.value == other.value return isinstance(other, Colour) and self.value == other.value
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __str__(self) -> str: def __str__(self) -> str:
@ -265,28 +264,28 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95A5A6) return cls(0x95A5A6)
lighter_gray: Callable[[Type[Self]], Self] = lighter_grey lighter_gray = lighter_grey
@classmethod @classmethod
def dark_grey(cls) -> Self: def dark_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607D8B) return cls(0x607D8B)
dark_gray: Callable[[Type[Self]], Self] = dark_grey dark_gray = dark_grey
@classmethod @classmethod
def light_grey(cls) -> Self: def light_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979C9F) return cls(0x979C9F)
light_gray: Callable[[Type[Self]], Self] = light_grey light_gray = light_grey
@classmethod @classmethod
def darker_grey(cls) -> Self: def darker_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546E7A) return cls(0x546E7A)
darker_gray: Callable[[Type[Self]], Self] = darker_grey darker_gray = darker_grey
@classmethod @classmethod
def og_blurple(cls) -> Self: def og_blurple(cls) -> Self:

10
discord/components.py

@ -374,9 +374,9 @@ class SelectOption:
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False, default: bool = False,
) -> None: ) -> None:
self.label = label self.label: str = label
self.value = label if value is MISSING else value self.value: str = label if value is MISSING else value
self.description = description self.description: Optional[str] = description
if emoji is not None: if emoji is not None:
if isinstance(emoji, str): if isinstance(emoji, str):
@ -386,8 +386,8 @@ class SelectOption:
else: else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji = emoji self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
self.default = default self.default: bool = default
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (

10
discord/context_managers.py

@ -25,13 +25,15 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import TYPE_CHECKING, Optional, Type from typing import TYPE_CHECKING, Optional, Type, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import Messageable from .abc import Messageable
from types import TracebackType from types import TracebackType
BE = TypeVar('BE', bound=BaseException)
# fmt: off # fmt: off
__all__ = ( __all__ = (
'Typing', 'Typing',
@ -67,13 +69,13 @@ class Typing:
async def __aenter__(self) -> None: async def __aenter__(self) -> None:
self._channel = channel = await self.messageable._get_channel() self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id) await channel._state.http.send_typing(channel.id)
self.task: asyncio.Task = self.loop.create_task(self.do_typing()) self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing())
self.task.add_done_callback(_typing_done_callback) self.task.add_done_callback(_typing_done_callback)
async def __aexit__( async def __aexit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BE]],
exc_value: Optional[BaseException], exc: Optional[BE],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
) -> None: ) -> None:
self.task.cancel() self.task.cancel()

190
discord/embeds.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import datetime import datetime
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, TypeVar, Union from typing import Any, Dict, List, Mapping, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
from . import utils from . import utils
from .colour import Colour from .colour import Colour
@ -37,20 +37,6 @@ __all__ = (
# fmt: on # fmt: on
class _EmptyEmbed:
def __bool__(self) -> bool:
return False
def __repr__(self) -> str:
return 'Embed.Empty'
def __len__(self) -> int:
return 0
EmptyEmbed: Final = _EmptyEmbed()
class EmbedProxy: class EmbedProxy:
def __init__(self, layer: Dict[str, Any]): def __init__(self, layer: Dict[str, Any]):
self.__dict__.update(layer) self.__dict__.update(layer)
@ -62,8 +48,8 @@ class EmbedProxy:
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_'))) inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
return f'EmbedProxy({inner})' return f'EmbedProxy({inner})'
def __getattr__(self, attr: str) -> _EmptyEmbed: def __getattr__(self, attr: str) -> None:
return EmptyEmbed return None
if TYPE_CHECKING: if TYPE_CHECKING:
@ -72,37 +58,36 @@ if TYPE_CHECKING:
from .types.embed import Embed as EmbedData, EmbedType from .types.embed import Embed as EmbedData, EmbedType
T = TypeVar('T') T = TypeVar('T')
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol): class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str] text: Optional[str]
icon_url: MaybeEmpty[str] icon_url: Optional[str]
class _EmbedFieldProxy(Protocol): class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str] name: Optional[str]
value: MaybeEmpty[str] value: Optional[str]
inline: bool inline: bool
class _EmbedMediaProxy(Protocol): class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str] url: Optional[str]
proxy_url: MaybeEmpty[str] proxy_url: Optional[str]
height: MaybeEmpty[int] height: Optional[int]
width: MaybeEmpty[int] width: Optional[int]
class _EmbedVideoProxy(Protocol): class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str] url: Optional[str]
height: MaybeEmpty[int] height: Optional[int]
width: MaybeEmpty[int] width: Optional[int]
class _EmbedProviderProxy(Protocol): class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str] name: Optional[str]
url: MaybeEmpty[str] url: Optional[str]
class _EmbedAuthorProxy(Protocol): class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str] name: Optional[str]
url: MaybeEmpty[str] url: Optional[str]
icon_url: MaybeEmpty[str] icon_url: Optional[str]
proxy_icon_url: MaybeEmpty[str] proxy_icon_url: Optional[str]
class Embed: class Embed:
@ -121,18 +106,15 @@ class Embed:
.. versionadded:: 2.0 .. versionadded:: 2.0
Certain properties return an ``EmbedProxy``, a type
that acts similar to a regular :class:`dict` except using dotted access,
e.g. ``embed.author.icon_url``. If the attribute
is invalid or empty, then a special sentinel value is returned,
:attr:`Embed.Empty`.
For ease of use, all parameters that expect a :class:`str` are implicitly For ease of use, all parameters that expect a :class:`str` are implicitly
casted to :class:`str` for you. casted to :class:`str` for you.
.. versionchanged:: 2.0
``Embed.Empty`` has been removed in favour of ``None``.
Attributes Attributes
----------- -----------
title: :class:`str` title: Optional[:class:`str`]
The title of the embed. The title of the embed.
This can be set during initialisation. This can be set during initialisation.
type: :class:`str` type: :class:`str`
@ -140,22 +122,19 @@ class Embed:
This can be set during initialisation. This can be set during initialisation.
Possible strings for embed types can be found on discord's Possible strings for embed types can be found on discord's
`api docs <https://discord.com/developers/docs/resources/channel#embed-object-embed-types>`_ `api docs <https://discord.com/developers/docs/resources/channel#embed-object-embed-types>`_
description: :class:`str` description: Optional[:class:`str`]
The description of the embed. The description of the embed.
This can be set during initialisation. This can be set during initialisation.
url: :class:`str` url: Optional[:class:`str`]
The URL of the embed. The URL of the embed.
This can be set during initialisation. This can be set during initialisation.
timestamp: :class:`datetime.datetime` timestamp: Optional[:class:`datetime.datetime`]
The timestamp of the embed content. This is an aware datetime. The timestamp of the embed content. This is an aware datetime.
If a naive datetime is passed, it is converted to an aware If a naive datetime is passed, it is converted to an aware
datetime with the local timezone. datetime with the local timezone.
colour: Union[:class:`Colour`, :class:`int`] colour: Optional[Union[:class:`Colour`, :class:`int`]]
The colour code of the embed. Aliased to ``color`` as well. The colour code of the embed. Aliased to ``color`` as well.
This can be set during initialisation. This can be set during initialisation.
Empty
A special sentinel value used by ``EmbedProxy`` and this class
to denote that the value or attribute is empty.
""" """
__slots__ = ( __slots__ = (
@ -174,36 +153,34 @@ class Embed:
'description', 'description',
) )
Empty: Final = EmptyEmbed
def __init__( def __init__(
self, self,
*, *,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, colour: Optional[Union[int, Colour]] = None,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, color: Optional[Union[int, Colour]] = None,
title: MaybeEmpty[Any] = EmptyEmbed, title: Optional[Any] = None,
type: EmbedType = 'rich', type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed, url: Optional[Any] = None,
description: MaybeEmpty[Any] = EmptyEmbed, description: Optional[Any] = None,
timestamp: MaybeEmpty[datetime.datetime] = EmptyEmbed, timestamp: Optional[datetime.datetime] = None,
): ):
self.colour = colour if colour is not EmptyEmbed else color self.colour = colour if colour is not None else color
self.title = title self.title: Optional[str] = title
self.type = type self.type: EmbedType = type
self.url = url self.url: Optional[str] = url
self.description = description self.description: Optional[str] = description
if self.title is not EmptyEmbed: if self.title is not None:
self.title = str(self.title) self.title = str(self.title)
if self.description is not EmptyEmbed: if self.description is not None:
self.description = str(self.description) self.description = str(self.description)
if self.url is not EmptyEmbed: if self.url is not None:
self.url = str(self.url) self.url = str(self.url)
if timestamp is not EmptyEmbed: if timestamp is not None:
self.timestamp = timestamp self.timestamp = timestamp
@classmethod @classmethod
@ -227,18 +204,18 @@ class Embed:
# fill in the basic fields # fill in the basic fields
self.title = data.get('title', EmptyEmbed) self.title = data.get('title', None)
self.type = data.get('type', EmptyEmbed) self.type = data.get('type', None)
self.description = data.get('description', EmptyEmbed) self.description = data.get('description', None)
self.url = data.get('url', EmptyEmbed) self.url = data.get('url', None)
if self.title is not EmptyEmbed: if self.title is not None:
self.title = str(self.title) self.title = str(self.title)
if self.description is not EmptyEmbed: if self.description is not None:
self.description = str(self.description) self.description = str(self.description)
if self.url is not EmptyEmbed: if self.url is not None:
self.url = str(self.url) self.url = str(self.url)
# try to fill in the more rich fields # try to fill in the more rich fields
@ -268,7 +245,7 @@ class Embed:
return self.__class__.from_dict(self.to_dict()) return self.__class__.from_dict(self.to_dict())
def __len__(self) -> int: def __len__(self) -> int:
total = len(self.title) + len(self.description) total = len(self.title or '') + len(self.description or '')
for field in getattr(self, '_fields', []): for field in getattr(self, '_fields', []):
total += len(field['name']) + len(field['value']) total += len(field['name']) + len(field['value'])
@ -307,34 +284,36 @@ class Embed:
) )
@property @property
def colour(self) -> MaybeEmpty[Colour]: def colour(self) -> Optional[Colour]:
return getattr(self, '_colour', EmptyEmbed) return getattr(self, '_colour', None)
@colour.setter @colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]): def colour(self, value: Optional[Union[int, Colour]]) -> None:
if isinstance(value, (Colour, _EmptyEmbed)): if value is None:
self._colour = None
elif isinstance(value, Colour):
self._colour = value self._colour = value
elif isinstance(value, int): elif isinstance(value, int):
self._colour = Colour(value=value) self._colour = Colour(value=value)
else: else:
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.') raise TypeError(f'Expected discord.Colour, int, or None but received {value.__class__.__name__} instead.')
color = colour color = colour
@property @property
def timestamp(self) -> MaybeEmpty[datetime.datetime]: def timestamp(self) -> Optional[datetime.datetime]:
return getattr(self, '_timestamp', EmptyEmbed) return getattr(self, '_timestamp', None)
@timestamp.setter @timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]): def timestamp(self, value: Optional[datetime.datetime]) -> None:
if isinstance(value, datetime.datetime): if isinstance(value, datetime.datetime):
if value.tzinfo is None: if value.tzinfo is None:
value = value.astimezone() value = value.astimezone()
self._timestamp = value self._timestamp = value
elif isinstance(value, _EmptyEmbed): elif value is None:
self._timestamp = value self._timestamp = None
else: else:
raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead") raise TypeError(f"Expected datetime.datetime or None received {value.__class__.__name__} instead")
@property @property
def footer(self) -> _EmbedFooterProxy: def footer(self) -> _EmbedFooterProxy:
@ -342,12 +321,12 @@ class Embed:
See :meth:`set_footer` for possible values you can access. See :meth:`set_footer` for possible values you can access.
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
def set_footer(self, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> Self: def set_footer(self, *, text: Optional[Any] = None, icon_url: Optional[Any] = None) -> Self:
"""Sets the footer for the embed content. """Sets the footer for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@ -362,10 +341,10 @@ class Embed:
""" """
self._footer = {} self._footer = {}
if text is not EmptyEmbed: if text is not None:
self._footer['text'] = str(text) self._footer['text'] = str(text)
if icon_url is not EmptyEmbed: if icon_url is not None:
self._footer['icon_url'] = str(icon_url) self._footer['icon_url'] = str(icon_url)
return self return self
@ -396,27 +375,24 @@ class Embed:
- ``width`` - ``width``
- ``height`` - ``height``
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_image', {})) # type: ignore return EmbedProxy(getattr(self, '_image', {})) # type: ignore
def set_image(self, *, url: MaybeEmpty[Any]) -> Self: def set_image(self, *, url: Optional[Any]) -> Self:
"""Sets the image for the embed content. """Sets the image for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
chaining. chaining.
.. versionchanged:: 1.4
Passing :attr:`Empty` removes the image.
Parameters Parameters
----------- -----------
url: :class:`str` url: :class:`str`
The source URL for the image. Only HTTP(S) is supported. The source URL for the image. Only HTTP(S) is supported.
""" """
if url is EmptyEmbed: if url is None:
try: try:
del self._image del self._image
except AttributeError: except AttributeError:
@ -439,19 +415,19 @@ class Embed:
- ``width`` - ``width``
- ``height`` - ``height``
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self, *, url: MaybeEmpty[Any]) -> Self: def set_thumbnail(self, *, url: Optional[Any]) -> Self:
"""Sets the thumbnail for the embed content. """Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
chaining. chaining.
.. versionchanged:: 1.4 .. versionchanged:: 1.4
Passing :attr:`Empty` removes the thumbnail. Passing ``None`` removes the thumbnail.
Parameters Parameters
----------- -----------
@ -459,7 +435,7 @@ class Embed:
The source URL for the thumbnail. Only HTTP(S) is supported. The source URL for the thumbnail. Only HTTP(S) is supported.
""" """
if url is EmptyEmbed: if url is None:
try: try:
del self._thumbnail del self._thumbnail
except AttributeError: except AttributeError:
@ -481,7 +457,7 @@ class Embed:
- ``height`` for the video height. - ``height`` for the video height.
- ``width`` for the video width. - ``width`` for the video width.
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_video', {})) # type: ignore return EmbedProxy(getattr(self, '_video', {})) # type: ignore
@ -492,7 +468,7 @@ class Embed:
The only attributes that might be accessed are ``name`` and ``url``. The only attributes that might be accessed are ``name`` and ``url``.
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore return EmbedProxy(getattr(self, '_provider', {})) # type: ignore
@ -503,12 +479,12 @@ class Embed:
See :meth:`set_author` for possible values you can access. See :meth:`set_author` for possible values you can access.
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_author', {})) # type: ignore return EmbedProxy(getattr(self, '_author', {})) # type: ignore
def set_author(self, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> Self: def set_author(self, *, name: Any, url: Optional[Any] = None, icon_url: Optional[Any] = None) -> Self:
"""Sets the author for the embed content. """Sets the author for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@ -528,10 +504,10 @@ class Embed:
'name': str(name), 'name': str(name),
} }
if url is not EmptyEmbed: if url is not None:
self._author['url'] = str(url) self._author['url'] = str(url)
if icon_url is not EmptyEmbed: if icon_url is not None:
self._author['icon_url'] = str(icon_url) self._author['icon_url'] = str(icon_url)
return self return self
@ -553,11 +529,11 @@ class Embed:
@property @property
def fields(self) -> List[_EmbedFieldProxy]: def fields(self) -> List[_EmbedFieldProxy]:
"""List[Union[``EmbedProxy``, :attr:`Empty`]]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents. """List[``EmbedProxy``]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents.
See :meth:`add_field` for possible values you can access. See :meth:`add_field` for possible values you can access.
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then ``None`` is returned.
""" """
# Lying to the type checker for better developer UX. # Lying to the type checker for better developer UX.
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore

10
discord/emoji.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Any, Iterator, List, Optional, TYPE_CHECKING, Tuple from typing import Any, Collection, Iterator, List, Optional, TYPE_CHECKING, Tuple
from .asset import Asset, AssetMixin from .asset import Asset, AssetMixin
from .utils import SnowflakeList, snowflake_time, MISSING from .utils import SnowflakeList, snowflake_time, MISSING
@ -142,10 +142,10 @@ class Emoji(_EmojiTag, AssetMixin):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>' return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id return isinstance(other, _EmojiTag) and self.id == other.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -214,7 +214,9 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild_id, self.id, reason=reason) await self._state.http.delete_custom_emoji(self.guild_id, self.id, reason=reason)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji: async def edit(
self, *, name: str = MISSING, roles: Collection[Snowflake] = MISSING, reason: Optional[str] = None
) -> Emoji:
r"""|coro| r"""|coro|
Edits the custom emoji. Edits the custom emoji.

45
discord/enums.py

@ -25,7 +25,7 @@ from __future__ import annotations
import types import types
from collections import namedtuple from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping
__all__ = ( __all__ = (
'Enum', 'Enum',
@ -149,38 +149,38 @@ class EnumMeta(type):
value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood
return actual_cls return actual_cls
def __iter__(cls): def __iter__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in cls._enum_member_names_) return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
def __reversed__(cls): def __reversed__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
def __len__(cls): def __len__(cls) -> int:
return len(cls._enum_member_names_) return len(cls._enum_member_names_)
def __repr__(cls): def __repr__(cls) -> str:
return f'<enum {cls.__name__}>' return f'<enum {cls.__name__}>'
@property @property
def __members__(cls): def __members__(cls) -> Mapping[str, Any]:
return types.MappingProxyType(cls._enum_member_map_) return types.MappingProxyType(cls._enum_member_map_)
def __call__(cls, value): def __call__(cls, value: str) -> Any:
try: try:
return cls._enum_value_map_[value] return cls._enum_value_map_[value]
except (KeyError, TypeError): except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}") raise ValueError(f"{value!r} is not a valid {cls.__name__}")
def __getitem__(cls, key): def __getitem__(cls, key: str) -> Any:
return cls._enum_member_map_[key] return cls._enum_member_map_[key]
def __setattr__(cls, name, value): def __setattr__(cls, name: str, value: Any) -> None:
raise TypeError('Enums are immutable') raise TypeError('Enums are immutable')
def __delattr__(cls, attr): def __delattr__(cls, attr: str) -> None:
raise TypeError('Enums are immutable') raise TypeError('Enums are immutable')
def __instancecheck__(self, instance): def __instancecheck__(self, instance: Any) -> bool:
# isinstance(x, Y) # isinstance(x, Y)
# -> __instancecheck__(Y, x) # -> __instancecheck__(Y, x)
try: try:
@ -215,7 +215,7 @@ class ChannelType(Enum):
private_thread = 12 private_thread = 12
stage_voice = 13 stage_voice = 13
def __str__(self): def __str__(self) -> str:
return self.name return self.name
def __int__(self): def __int__(self):
@ -258,10 +258,10 @@ class SpeakingState(Enum):
soundshare = 2 soundshare = 2
priority = 4 priority = 4
def __str__(self): def __str__(self) -> str:
return self.name return self.name
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -272,7 +272,7 @@ class VerificationLevel(Enum, comparable=True):
high = 3 high = 3
highest = 4 highest = 4
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -281,7 +281,7 @@ class ContentFilter(Enum, comparable=True):
no_role = 1 no_role = 1
all_members = 2 all_members = 2
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -347,7 +347,7 @@ class Status(Enum):
do_not_disturb = 'dnd' do_not_disturb = 'dnd'
invisible = 'invisible' invisible = 'invisible'
def __str__(self): def __str__(self) -> str:
return self.value return self.value
@ -360,7 +360,7 @@ class DefaultAvatar(Enum):
red = 4 red = 4
pink = 5 pink = 5
def __str__(self): def __str__(self) -> str:
return self.name return self.name
@ -554,6 +554,7 @@ class UserFlags(Enum):
discord_certified_moderator = 262144 discord_certified_moderator = 262144
bot_http_interactions = 524288 bot_http_interactions = 524288
spammer = 1048576 spammer = 1048576
disable_premium = 2097152
class ActivityType(Enum): class ActivityType(Enum):
@ -565,7 +566,7 @@ class ActivityType(Enum):
custom = 4 custom = 4
competing = 5 competing = 5
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -690,7 +691,7 @@ class VideoQualityMode(Enum):
auto = 1 auto = 1
full = 2 full = 2
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -700,7 +701,7 @@ class ComponentType(Enum):
select = 3 select = 3
text_input = 4 text_input = 4
def __int__(self): def __int__(self) -> int:
return self.value return self.value
@ -719,7 +720,7 @@ class ButtonStyle(Enum):
red = 4 red = 4
url = 5 url = 5
def __int__(self): def __int__(self) -> int:
return self.value return self.value

24
discord/ext/commands/_types.py

@ -23,26 +23,42 @@ DEALINGS IN THE SOFTWARE.
""" """
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple
T = TypeVar('T')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec
from .bot import Bot
from .context import Context from .context import Context
from .cog import Cog from .cog import Cog
from .errors import CommandError from .errors import CommandError
T = TypeVar('T') P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, 'Coro[T]'],
Callable[P, T],
]
else:
P = TypeVar('P')
MaybeCoroFunc = Tuple[P, T]
_Bot = Bot
Coro = Coroutine[Any, Any, T] Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]] MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]] CoroFunc = Callable[..., Coro[Any]]
ContextT = TypeVar('ContextT', bound='Context')
Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]] Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
Error = Union[Callable[["Cog", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]] Error = Union[Callable[["Cog", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
ContextT = TypeVar('ContextT', bound='Context[Any]')
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
ErrorT = TypeVar('ErrorT', bound='Error[Context[Any]]')
HookT = TypeVar('HookT', bound='Hook[Context[Any]]')
# This is merely a tag type to avoid circular import issues. # This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution. # Yes, this is a terrible solution but ultimately it is the only solution.

334
discord/ext/commands/bot.py

@ -33,9 +33,24 @@ import importlib.util
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
Iterable,
Collection,
overload,
)
import discord import discord
from discord.utils import MISSING, _is_submodule
from .core import GroupMixin from .core import GroupMixin
from .view import StringView from .view import StringView
@ -50,36 +65,44 @@ if TYPE_CHECKING:
import importlib.machinery import importlib.machinery
from discord.message import Message from discord.message import Message
from discord.abc import User from discord.abc import User, Snowflake
from ._types import ( from ._types import (
_Bot,
BotT,
Check, Check,
CoroFunc, CoroFunc,
ContextT,
MaybeCoroFunc,
) )
_Prefix = Union[Iterable[str], str]
_PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix]
PrefixType = Union[_Prefix, _PrefixCallable[BotT]]
__all__ = ( __all__ = (
'when_mentioned', 'when_mentioned',
'when_mentioned_or', 'when_mentioned_or',
'Bot', 'Bot',
) )
MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc') CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Bot')
def when_mentioned(bot: Bot, msg: Message) -> List[str]: def when_mentioned(bot: _Bot, msg: Message, /) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned. """A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
.. versionchanged:: 2.0
``bot`` and ``msg`` parameters are now positional-only.
""" """
# bot.user will never be None when this is called # bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Bot, Message], List[str]]: def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided. """A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -117,34 +140,38 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Bot, Message], List[str]]:
return inner return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr: class _DefaultRepr:
def __repr__(self): def __repr__(self):
return '<default-help-command>' return '<default-help-command>'
_default = _DefaultRepr() _default: Any = _DefaultRepr()
class BotBase(GroupMixin): class BotBase(GroupMixin[None]):
def __init__(self, command_prefix, help_command=_default, description=None, **options): def __init__(
self,
command_prefix: PrefixType[BotT],
help_command: Optional[HelpCommand[Any]] = _default,
description: Optional[str] = None,
**options: Any,
) -> None:
super().__init__(**options) super().__init__(**options)
self.command_prefix = command_prefix self.command_prefix: PrefixType[BotT] = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {} self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
self.__cogs: Dict[str, Cog] = {} self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {} self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = [] self._checks: List[Check] = []
self._check_once = [] self._check_once: List[Check] = []
self._before_invoke = None self._before_invoke: Optional[CoroFunc] = None
self._after_invoke = None self._after_invoke: Optional[CoroFunc] = None
self._help_command = None self._help_command: Optional[HelpCommand[Any]] = None
self.description = inspect.cleandoc(description) if description else '' self.description: str = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id') self.owner_id: Optional[int] = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set()) self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False) self.strip_after_prefix: bool = options.get('strip_after_prefix', False)
if self.owner_id and self.owner_ids: if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set') raise TypeError('Both owner_id and owner_ids are set')
@ -172,7 +199,7 @@ class BotBase(GroupMixin):
# internal helpers # internal helpers
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: def dispatch(self, event_name: str, /, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client # super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name ev = 'on_' + event_name
@ -183,19 +210,19 @@ class BotBase(GroupMixin):
async def close(self) -> None: async def close(self) -> None:
for extension in tuple(self.__extensions): for extension in tuple(self.__extensions):
try: try:
self.unload_extension(extension) await self.unload_extension(extension)
except Exception: except Exception:
pass pass
for cog in tuple(self.__cogs): for cog in tuple(self.__cogs):
try: try:
self.remove_cog(cog) await self.remove_cog(cog)
except Exception: except Exception:
pass pass
await super().close() # type: ignore await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: async def on_command_error(self, context: Context[BotT], exception: errors.CommandError, /) -> None:
"""|coro| """|coro|
The default command error handler provided by the bot. The default command error handler provided by the bot.
@ -204,6 +231,10 @@ class BotBase(GroupMixin):
overridden to have a different implementation. overridden to have a different implementation.
This only fires if you do not specify any listeners for command error. This only fires if you do not specify any listeners for command error.
.. versionchanged:: 2.0
``context`` and ``exception`` parameters are now positional-only.
""" """
if self.extra_events.get('on_command_error', None): if self.extra_events.get('on_command_error', None):
return return
@ -221,7 +252,7 @@ class BotBase(GroupMixin):
# global check registration # global check registration
def check(self, func: T) -> T: def check(self, func: T, /) -> T:
r"""A decorator that adds a global check to the bot. r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied A global check is similar to a :func:`.check` that is applied
@ -245,12 +276,15 @@ class BotBase(GroupMixin):
def check_commands(ctx): def check_commands(ctx):
return ctx.command.qualified_name in allowed_commands return ctx.command.qualified_name in allowed_commands
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
""" """
# T was used instead of Check to ensure the type matches on return # T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore self.add_check(func) # type: ignore
return func return func
def add_check(self, func: Check, /, *, call_once: bool = False) -> None: def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot. """Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check` This is the non-decorator interface to :meth:`.check`
@ -274,7 +308,7 @@ class BotBase(GroupMixin):
else: else:
self._checks.append(func) self._checks.append(func)
def remove_check(self, func: Check, /, *, call_once: bool = False) -> None: def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot. """Removes a global check from the bot.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -299,7 +333,7 @@ class BotBase(GroupMixin):
except ValueError: except ValueError:
pass pass
def check_once(self, func: CFT) -> CFT: def check_once(self, func: CFT, /) -> CFT:
r"""A decorator that adds a "call once" global check to the bot. r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once Unlike regular global checks, this one is called only once
@ -333,11 +367,15 @@ class BotBase(GroupMixin):
def whitelist(ctx): def whitelist(ctx):
return ctx.message.author.id in my_whitelist return ctx.message.author.id in my_whitelist
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
""" """
self.add_check(func, call_once=True) self.add_check(func, call_once=True)
return func return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: async def can_run(self, ctx: Context[BotT], /, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks data = self._check_once if call_once else self._checks
if len(data) == 0: if len(data) == 0:
@ -346,12 +384,15 @@ class BotBase(GroupMixin):
# type-checker doesn't distinguish between functions and methods # type-checker doesn't distinguish between functions and methods
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user: User) -> bool: async def is_owner(self, user: User, /) -> bool:
"""|coro| """|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
this bot. this bot.
.. versionchanged:: 2.0
``user`` parameter is now positional-only.
Parameters Parameters
----------- -----------
user: :class:`.abc.User` user: :class:`.abc.User`
@ -374,7 +415,7 @@ class BotBase(GroupMixin):
else: else:
raise AttributeError('Owners aren\'t set.') raise AttributeError('Owners aren\'t set.')
def before_invoke(self, coro: CFT) -> CFT: def before_invoke(self, coro: CFT, /) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is A pre-invoke hook is called directly before the command is
@ -390,6 +431,10 @@ class BotBase(GroupMixin):
without error. If any check or argument parsing procedures fail without error. If any check or argument parsing procedures fail
then the hooks are not called. then the hooks are not called.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters Parameters
----------- -----------
coro: :ref:`coroutine <coroutine>` coro: :ref:`coroutine <coroutine>`
@ -406,7 +451,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro self._before_invoke = coro
return coro return coro
def after_invoke(self, coro: CFT) -> CFT: def after_invoke(self, coro: CFT, /) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook. r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is A post-invoke hook is called directly after the command is
@ -423,6 +468,10 @@ class BotBase(GroupMixin):
callback raising an error (i.e. :exc:`.CommandInvokeError`\). callback raising an error (i.e. :exc:`.CommandInvokeError`\).
This makes it ideal for clean-up scenarios. This makes it ideal for clean-up scenarios.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters Parameters
----------- -----------
coro: :ref:`coroutine <coroutine>` coro: :ref:`coroutine <coroutine>`
@ -441,9 +490,13 @@ class BotBase(GroupMixin):
# listener registration # listener registration
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None: def add_listener(self, func: CoroFunc, /, name: str = MISSING) -> None:
"""The non decorator alternative to :meth:`.listen`. """The non decorator alternative to :meth:`.listen`.
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters Parameters
----------- -----------
func: :ref:`coroutine <coroutine>` func: :ref:`coroutine <coroutine>`
@ -473,9 +526,13 @@ class BotBase(GroupMixin):
else: else:
self.extra_events[name] = [func] self.extra_events[name] = [func]
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None: def remove_listener(self, func: CoroFunc, /, name: str = MISSING) -> None:
"""Removes a listener from the pool of listeners. """Removes a listener from the pool of listeners.
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters Parameters
----------- -----------
func func
@ -531,11 +588,29 @@ class BotBase(GroupMixin):
# cogs # cogs
def add_cog(self, cog: Cog, /, *, override: bool = False) -> None: async def add_cog(
"""Adds a "cog" to the bot. self,
cog: Cog,
/,
*,
override: bool = False,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> None:
"""|coro|
Adds a "cog" to the bot.
A cog is a class that has its own event listeners and commands. A cog is a class that has its own event listeners and commands.
If the cog is a :class:`.app_commands.Group` then it is added to
the bot's :class:`~discord.app_commands.CommandTree` as well.
.. note::
Exceptions raised inside a :class:`.Cog`'s :meth:`~.Cog.cog_load` method will be
propagated to the caller.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
:exc:`.ClientException` is raised when a cog with the same name :exc:`.ClientException` is raised when a cog with the same name
@ -545,6 +620,10 @@ class BotBase(GroupMixin):
``cog`` parameter is now positional-only. ``cog`` parameter is now positional-only.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters Parameters
----------- -----------
cog: :class:`.Cog` cog: :class:`.Cog`
@ -553,6 +632,19 @@ class BotBase(GroupMixin):
If a previously loaded cog with the same name should be ejected If a previously loaded cog with the same name should be ejected
instead of raising an error. instead of raising an error.
.. versionadded:: 2.0
guild: Optional[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guild where the cog group would be added to. If not given then
it becomes a global command instead.
.. versionadded:: 2.0
guilds: List[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guilds where the cog group would be added to. If not given then
it becomes a global command instead. Cannot be mixed with
``guild``.
.. versionadded:: 2.0 .. versionadded:: 2.0
Raises Raises
@ -574,9 +666,12 @@ class BotBase(GroupMixin):
if existing is not None: if existing is not None:
if not override: if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded') raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
self.remove_cog(cog_name) await self.remove_cog(cog_name, guild=guild, guilds=guilds)
cog = cog._inject(self) if isinstance(cog, app_commands.Group):
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
cog = await cog._inject(self, override=override, guild=guild, guilds=guilds)
self.__cogs[cog_name] = cog self.__cogs[cog_name] = cog
def get_cog(self, name: str, /) -> Optional[Cog]: def get_cog(self, name: str, /) -> Optional[Cog]:
@ -602,8 +697,17 @@ class BotBase(GroupMixin):
""" """
return self.__cogs.get(name) return self.__cogs.get(name)
def remove_cog(self, name: str, /) -> Optional[Cog]: async def remove_cog(
"""Removes a cog from the bot and returns it. self,
name: str,
/,
*,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> Optional[Cog]:
"""|coro|
Removes a cog from the bot and returns it.
All registered commands and event listeners that the All registered commands and event listeners that the
cog has registered will be removed as well. cog has registered will be removed as well.
@ -614,10 +718,27 @@ class BotBase(GroupMixin):
``name`` parameter is now positional-only. ``name`` parameter is now positional-only.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters Parameters
----------- -----------
name: :class:`str` name: :class:`str`
The name of the cog to remove. The name of the cog to remove.
guild: Optional[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guild where the cog group would be removed from. If not given then
a global command is removed instead instead.
.. versionadded:: 2.0
guilds: List[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guilds where the cog group would be removed from. If not given then
a global command is removed instead instead. Cannot be mixed with
``guild``.
.. versionadded:: 2.0
Returns Returns
------- -------
@ -632,7 +753,16 @@ class BotBase(GroupMixin):
help_command = self._help_command help_command = self._help_command
if help_command and help_command.cog is cog: if help_command and help_command.cog is cog:
help_command.cog = None help_command.cog = None
cog._eject(self)
guild_ids = _retrieve_guild_ids(cog, guild, guilds)
if isinstance(cog, app_commands.Group):
if guild_ids is None:
self.__tree.remove_command(name)
else:
for guild_id in guild_ids:
self.__tree.remove_command(name, guild=discord.Object(guild_id))
await cog._eject(self, guild_ids=guild_ids)
return cog return cog
@ -643,12 +773,12 @@ class BotBase(GroupMixin):
# extensions # extensions
def _remove_module_references(self, name: str) -> None: async def _remove_module_references(self, name: str) -> None:
# find all references to the module # find all references to the module
# remove the cogs registered from the module # remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items(): for cogname, cog in self.__cogs.copy().items():
if _is_submodule(name, cog.__module__): if _is_submodule(name, cog.__module__):
self.remove_cog(cogname) await self.remove_cog(cogname)
# remove all the commands from the module # remove all the commands from the module
for cmd in self.all_commands.copy().values(): for cmd in self.all_commands.copy().values():
@ -667,14 +797,17 @@ class BotBase(GroupMixin):
for index in reversed(remove): for index in reversed(remove):
del event_list[index] del event_list[index]
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: # remove all relevant application commands from the tree
self.__tree._remove_with_module(name)
async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try: try:
func = getattr(lib, 'teardown') func = getattr(lib, 'teardown')
except AttributeError: except AttributeError:
pass pass
else: else:
try: try:
func(self) await func(self)
except Exception: except Exception:
pass pass
finally: finally:
@ -685,7 +818,7 @@ class BotBase(GroupMixin):
if _is_submodule(name, module): if _is_submodule(name, module):
del sys.modules[module] del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: async def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions # precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec) lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib sys.modules[key] = lib
@ -702,11 +835,11 @@ class BotBase(GroupMixin):
raise errors.NoEntryPointError(key) raise errors.NoEntryPointError(key)
try: try:
setup(self) await setup(self)
except Exception as e: except Exception as e:
del sys.modules[key] del sys.modules[key]
self._remove_module_references(lib.__name__) await self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, key) await self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e raise errors.ExtensionFailed(key, e) from e
else: else:
self.__extensions[key] = lib self.__extensions[key] = lib
@ -717,8 +850,10 @@ class BotBase(GroupMixin):
except ImportError: except ImportError:
raise errors.ExtensionNotFound(name) raise errors.ExtensionNotFound(name)
def load_extension(self, name: str, *, package: Optional[str] = None) -> None: async def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension. """|coro|
Loads an extension.
An extension is a python module that contains commands, cogs, or An extension is a python module that contains commands, cogs, or
listeners. listeners.
@ -727,6 +862,10 @@ class BotBase(GroupMixin):
the entry point on what to do when the extension is loaded. This entry the entry point on what to do when the extension is loaded. This entry
point must have a single argument, the ``bot``. point must have a single argument, the ``bot``.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters Parameters
------------ ------------
name: :class:`str` name: :class:`str`
@ -762,10 +901,12 @@ class BotBase(GroupMixin):
if spec is None: if spec is None:
raise errors.ExtensionNotFound(name) raise errors.ExtensionNotFound(name)
self._load_from_module_spec(spec, name) await self._load_from_module_spec(spec, name)
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: async def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension. """|coro|
Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are When the extension is unloaded, all commands, listeners, and cogs are
removed from the bot and the module is un-imported. removed from the bot and the module is un-imported.
@ -775,6 +916,10 @@ class BotBase(GroupMixin):
parameter, the ``bot``, similar to ``setup`` from parameter, the ``bot``, similar to ``setup`` from
:meth:`~.Bot.load_extension`. :meth:`~.Bot.load_extension`.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters Parameters
------------ ------------
name: :class:`str` name: :class:`str`
@ -802,10 +947,10 @@ class BotBase(GroupMixin):
if lib is None: if lib is None:
raise errors.ExtensionNotLoaded(name) raise errors.ExtensionNotLoaded(name)
self._remove_module_references(lib.__name__) await self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name) await self._call_module_finalizers(lib, name)
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: async def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension. """Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is This replaces the extension with the same extension, only refreshed. This is
@ -856,14 +1001,14 @@ class BotBase(GroupMixin):
try: try:
# Unload and then load the module... # Unload and then load the module...
self._remove_module_references(lib.__name__) await self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name) await self._call_module_finalizers(lib, name)
self.load_extension(name) await self.load_extension(name)
except Exception: except Exception:
# if the load failed, the remnants should have been # if the load failed, the remnants should have been
# cleaned from the load_extension function call # cleaned from the load_extension function call
# so let's load it from our old compiled library. # so let's load it from our old compiled library.
lib.setup(self) # type: ignore await lib.setup(self)
self.__extensions[name] = lib self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller # revert sys.modules back to normal and raise back to caller
@ -878,11 +1023,11 @@ class BotBase(GroupMixin):
# help command stuff # help command stuff
@property @property
def help_command(self) -> Optional[HelpCommand]: def help_command(self) -> Optional[HelpCommand[Any]]:
return self._help_command return self._help_command
@help_command.setter @help_command.setter
def help_command(self, value: Optional[HelpCommand]) -> None: def help_command(self, value: Optional[HelpCommand[Any]]) -> None:
if value is not None: if value is not None:
if not isinstance(value, HelpCommand): if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand') raise TypeError('help_command must be a subclass of HelpCommand')
@ -896,14 +1041,32 @@ class BotBase(GroupMixin):
else: else:
self._help_command = None self._help_command = None
# application command interop
# As mentioned above, this is a mixin so the Self type hint fails here.
# However, since the only classes that can use this are subclasses of Client
# anyway, then this is sound.
@property
def tree(self) -> app_commands.CommandTree[Self]: # type: ignore
""":class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands
in this bot.
.. versionadded:: 2.0
"""
return self.__tree
# command processing # command processing
async def get_prefix(self, message: Message) -> Union[List[str], str]: async def get_prefix(self, message: Message, /) -> Union[List[str], str]:
"""|coro| """|coro|
Retrieves the prefix the bot is listening to Retrieves the prefix the bot is listening to
with the message as a context. with the message as a context.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
Parameters Parameters
----------- -----------
message: :class:`discord.Message` message: :class:`discord.Message`
@ -917,11 +1080,12 @@ class BotBase(GroupMixin):
""" """
prefix = ret = self.command_prefix prefix = ret = self.command_prefix
if callable(prefix): if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message) # self will be a Bot or AutoShardedBot
ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore
if not isinstance(ret, str): if not isinstance(ret, str):
try: try:
ret = list(ret) ret = list(ret) # type: ignore
except TypeError: except TypeError:
# It's possible that a generator raised this exception. Don't # It's possible that a generator raised this exception. Don't
# replace it with our own error if that's the case. # replace it with our own error if that's the case.
@ -942,6 +1106,7 @@ class BotBase(GroupMixin):
async def get_context( async def get_context(
self, self,
message: Message, message: Message,
/,
) -> Context[Self]: # type: ignore ) -> Context[Self]: # type: ignore
... ...
@ -949,16 +1114,18 @@ class BotBase(GroupMixin):
async def get_context( async def get_context(
self, self,
message: Message, message: Message,
/,
*, *,
cls: Type[CXT] = ..., cls: Type[ContextT] = ...,
) -> CXT: # type: ignore ) -> ContextT:
... ...
async def get_context( async def get_context(
self, self,
message: Message, message: Message,
/,
*, *,
cls: Type[CXT] = MISSING, cls: Type[ContextT] = MISSING,
) -> Any: ) -> Any:
r"""|coro| r"""|coro|
@ -972,6 +1139,10 @@ class BotBase(GroupMixin):
If the context is not valid then it is not a valid candidate to be If the context is not valid then it is not a valid candidate to be
invoked under :meth:`~.Bot.invoke`. invoked under :meth:`~.Bot.invoke`.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
Parameters Parameters
----------- -----------
message: :class:`discord.Message` message: :class:`discord.Message`
@ -1039,12 +1210,16 @@ class BotBase(GroupMixin):
ctx.command = self.all_commands.get(invoker) ctx.command = self.all_commands.get(invoker)
return ctx return ctx
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx: Context[BotT], /) -> None:
"""|coro| """|coro|
Invokes the command given under the invocation context and Invokes the command given under the invocation context and
handles all the internal event dispatch mechanisms. handles all the internal event dispatch mechanisms.
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters Parameters
----------- -----------
ctx: :class:`.Context` ctx: :class:`.Context`
@ -1065,7 +1240,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc) self.dispatch('command_error', ctx, exc)
async def process_commands(self, message: Message) -> None: async def process_commands(self, message: Message, /) -> None:
"""|coro| """|coro|
This function processes the commands that have been registered This function processes the commands that have been registered
@ -1082,6 +1257,10 @@ class BotBase(GroupMixin):
This also checks if the message's author is a bot and doesn't This also checks if the message's author is a bot and doesn't
call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so. call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
Parameters Parameters
----------- -----------
message: :class:`discord.Message` message: :class:`discord.Message`
@ -1091,9 +1270,10 @@ class BotBase(GroupMixin):
return return
ctx = await self.get_context(message) ctx = await self.get_context(message)
await self.invoke(ctx) # the type of the invocation context's bot attribute will be correct
await self.invoke(ctx) # type: ignore
async def on_message(self, message): async def on_message(self, message: Message, /) -> None:
await self.process_commands(message) await self.process_commands(message)

134
discord/ext/commands/cog.py

@ -24,14 +24,17 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import inspect import inspect
import discord.utils import discord
from discord import app_commands
from discord.utils import maybe_coroutine
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from ._types import _BaseCommand from ._types import _BaseCommand, BotT
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
from discord.abc import Snowflake
from .bot import BotBase from .bot import BotBase
from .context import Context from .context import Context
@ -109,20 +112,35 @@ class CogMeta(type):
__cog_name__: str __cog_name__: str
__cog_settings__: Dict[str, Any] __cog_settings__: Dict[str, Any]
__cog_commands__: List[Command] __cog_commands__: List[Command[Any, ..., Any]]
__cog_is_app_commands_group__: bool
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]] __cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self: def __new__(cls, *args: Any, **kwargs: Any) -> Self:
name, bases, attrs = args name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name) attrs['__cog_name__'] = kwargs.get('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
is_parent = any(issubclass(base, app_commands.Group) for base in bases)
attrs['__cog_is_app_commands_group__'] = is_parent
description = kwargs.pop('description', None) description = kwargs.get('description', None)
if description is None: if description is None:
description = inspect.cleandoc(attrs.get('__doc__', '')) description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description attrs['__cog_description__'] = description
if is_parent:
attrs['__discord_app_commands_skip_init_binding__'] = True
# This is hacky, but it signals the Group not to process this info.
# It's overridden later.
attrs['__discord_app_commands_group_children__'] = True
else:
# Remove the extraneous keyword arguments we're using
kwargs.pop('name', None)
kwargs.pop('description', None)
commands = {} commands = {}
cog_app_commands = {}
listeners = {} listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
@ -143,6 +161,8 @@ class CogMeta(type):
if elem.startswith(('cog_', 'bot_')): if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem)) raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value commands[elem] = value
elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None:
cog_app_commands[elem] = value
elif inspect.iscoroutinefunction(value): elif inspect.iscoroutinefunction(value):
try: try:
getattr(value, '__cog_listener__') getattr(value, '__cog_listener__')
@ -154,6 +174,13 @@ class CogMeta(type):
listeners[elem] = value listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_app_commands__ = list(cog_app_commands.values())
if is_parent:
# Prefill the app commands for the Group as well..
# The type checker doesn't like runtime attribute modification and this one's
# optional so it can't be cheesed.
new_cls.__discord_app_commands_group_children__ = new_cls.__cog_app_commands__ # type: ignore
listeners_as_list = [] listeners_as_list = []
for listener in listeners.values(): for listener in listeners.values():
@ -189,10 +216,11 @@ class Cog(metaclass=CogMeta):
are equally valid here. are equally valid here.
""" """
__cog_name__: ClassVar[str] __cog_name__: str
__cog_settings__: ClassVar[Dict[str, Any]] __cog_settings__: Dict[str, Any]
__cog_commands__: ClassVar[List[Command[Self, ..., Any]]] __cog_commands__: List[Command[Self, ..., Any]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]] __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self: def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# For issue 426, we need to store a copy of the command objects # For issue 426, we need to store a copy of the command objects
@ -219,6 +247,27 @@ class Cog(metaclass=CogMeta):
parent.remove_command(command.name) # type: ignore parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore parent.add_command(command) # type: ignore
# Register the application commands
children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = []
for command in cls.__cog_app_commands__:
if cls.__cog_is_app_commands_group__:
# Type checker doesn't understand this type of narrowing.
# Not even with TypeGuard somehow.
command.parent = self # type: ignore
copy = command._copy_with_binding(self)
children.append(copy)
if command._attr:
setattr(self, command._attr, copy)
self.__cog_app_commands__ = children
if cls.__cog_is_app_commands_group__:
# Dynamic attribute setting
self.__discord_app_commands_group_children__ = children # type: ignore
# Enforce this to work even if someone forgets __init__
self.module = cls.__module__ # type: ignore
return self return self
def get_commands(self) -> List[Command[Self, ..., Any]]: def get_commands(self) -> List[Command[Self, ..., Any]]:
@ -330,18 +379,35 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method @_cog_special_method
def cog_unload(self) -> None: async def cog_load(self) -> None:
"""A special method that is called when the cog gets removed. """|maybecoro|
A special method that is called when the cog gets loaded.
This function **cannot** be a coroutine. It must be a regular Subclasses must replace this if they want special asynchronous loading behaviour.
function. Note that the ``__init__`` special method does not allow asynchronous code to run
inside it, thus this is helpful for setting up code that needs to be asynchronous.
.. versionadded:: 2.0
"""
pass
@_cog_special_method
async def cog_unload(self) -> None:
"""|maybecoro|
A special method that is called when the cog gets removed.
Subclasses must replace this if they want special unloading behaviour. Subclasses must replace this if they want special unloading behaviour.
.. versionchanged:: 2.0
This method can now be a :term:`coroutine`.
""" """
pass pass
@_cog_special_method @_cog_special_method
def bot_check_once(self, ctx: Context) -> bool: def bot_check_once(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once` """A special method that registers as a :meth:`.Bot.check_once`
check. check.
@ -351,7 +417,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def bot_check(self, ctx: Context) -> bool: def bot_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check` """A special method that registers as a :meth:`.Bot.check`
check. check.
@ -361,7 +427,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def cog_check(self, ctx: Context) -> bool: def cog_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check` """A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog. for every command and subcommand in this cog.
@ -371,7 +437,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None: async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
"""A special method that is called whenever an error """A special method that is called whenever an error
is dispatched inside this cog. is dispatched inside this cog.
@ -390,7 +456,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None: async def cog_before_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local pre-invoke hook. """A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`. This is similar to :meth:`.Command.before_invoke`.
@ -405,7 +471,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None: async def cog_after_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local post-invoke hook. """A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`. This is similar to :meth:`.Command.after_invoke`.
@ -419,9 +485,13 @@ class Cog(metaclass=CogMeta):
""" """
pass pass
def _inject(self, bot: BotBase) -> Self: async def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self:
cls = self.__class__ cls = self.__class__
# we'll call this first so that errors can propagate without
# having to worry about undoing anything
await maybe_coroutine(self.cog_load)
# realistically, the only thing that can cause loading errors # realistically, the only thing that can cause loading errors
# is essentially just the command loading, which raises if there are # is essentially just the command loading, which raises if there are
# duplicates. When this condition is met, we want to undo all what # duplicates. When this condition is met, we want to undo all what
@ -430,7 +500,8 @@ class Cog(metaclass=CogMeta):
command.cog = self command.cog = self
if command.parent is None: if command.parent is None:
try: try:
bot.add_command(command) # Type checker does not understand the generic bounds here
bot.add_command(command) # type: ignore
except Exception as e: except Exception as e:
# undo our additions # undo our additions
for to_undo in self.__cog_commands__[:index]: for to_undo in self.__cog_commands__[:index]:
@ -452,9 +523,15 @@ class Cog(metaclass=CogMeta):
for name, method_name in self.__cog_listeners__: for name, method_name in self.__cog_listeners__:
bot.add_listener(getattr(self, method_name), name) bot.add_listener(getattr(self, method_name), name)
# Only do this if these are "top level" commands
if not cls.__cog_is_app_commands_group__:
for command in self.__cog_app_commands__:
# This is already atomic
bot.tree.add_command(command, override=override, guild=guild, guilds=guilds)
return self return self
def _eject(self, bot: BotBase) -> None: async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None:
cls = self.__class__ cls = self.__class__
try: try:
@ -462,6 +539,15 @@ class Cog(metaclass=CogMeta):
if command.parent is None: if command.parent is None:
bot.remove_command(command.name) bot.remove_command(command.name)
if not cls.__cog_is_app_commands_group__:
for command in self.__cog_app_commands__:
guild_ids = guild_ids or command._guild_ids
if guild_ids is None:
bot.tree.remove_command(command.name)
else:
for guild_id in guild_ids:
bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id))
for name, method_name in self.__cog_listeners__: for name, method_name in self.__cog_listeners__:
bot.remove_listener(getattr(self, method_name), name) bot.remove_listener(getattr(self, method_name), name)
@ -472,6 +558,6 @@ class Cog(metaclass=CogMeta):
bot.remove_check(self.bot_check_once, call_once=True) bot.remove_check(self.bot_check_once, call_once=True)
finally: finally:
try: try:
self.cog_unload() await maybe_coroutine(self.cog_unload)
except Exception: except Exception:
pass pass

16
discord/ext/commands/context.py

@ -28,6 +28,8 @@ import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
from ._types import BotT
import discord.abc import discord.abc
import discord.utils import discord.utils
@ -58,7 +60,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
BotT = TypeVar('BotT', bound="Bot")
CogT = TypeVar('CogT', bound="Cog") CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING: if TYPE_CHECKING:
@ -132,10 +133,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
args: List[Any] = MISSING, args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING, kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None, prefix: Optional[str] = None,
command: Optional[Command] = None, command: Optional[Command[Any, ..., Any]] = None,
invoked_with: Optional[str] = None, invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING, invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None, invoked_subcommand: Optional[Command[Any, ..., Any]] = None,
subcommand_passed: Optional[str] = None, subcommand_passed: Optional[str] = None,
command_failed: bool = False, command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None, current_parameter: Optional[inspect.Parameter] = None,
@ -145,11 +146,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.args: List[Any] = args or [] self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {} self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command self.command: Optional[Command[Any, ..., Any]] = command
self.view: StringView = view self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or [] self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter self.current_parameter: Optional[inspect.Parameter] = current_parameter
@ -352,6 +353,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
""" """
from .core import Group, Command, wrap_callback from .core import Group, Command, wrap_callback
from .errors import CommandError from .errors import CommandError
from .help import _context
bot = self.bot bot = self.bot
cmd = bot.help_command cmd = bot.help_command
@ -359,8 +361,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if cmd is None: if cmd is None:
return None return None
cmd = cmd.copy() _context.set(self)
cmd.context = self
if len(args) == 0: if len(args) == 0:
await cmd.prepare_help_command(self, None) await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping() mapping = cmd.get_bot_mapping()

153
discord/ext/commands/converter.py

@ -41,7 +41,6 @@ from typing import (
Tuple, Tuple,
Union, Union,
runtime_checkable, runtime_checkable,
overload,
) )
import discord import discord
@ -51,9 +50,8 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
from discord.state import Channel from discord.state import Channel
from discord.threads import Thread from discord.threads import Thread
from .bot import Bot
_Bot = Bot from ._types import BotT, _Bot
__all__ = ( __all__ = (
@ -80,13 +78,14 @@ __all__ = (
'ThreadConverter', 'ThreadConverter',
'GuildChannelConverter', 'GuildChannelConverter',
'GuildStickerConverter', 'GuildStickerConverter',
'ScheduledEventConverter',
'clean_content', 'clean_content',
'Greedy', 'Greedy',
'run_converters', 'run_converters',
) )
def _get_from_guilds(bot, getter, argument): def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
result = None result = None
for guild in bot.guilds: for guild in bot.guilds:
result = getattr(guild, getter)(argument) result = getattr(guild, getter)(argument)
@ -114,7 +113,7 @@ class Converter(Protocol[T_co]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`. method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
""" """
async def convert(self, ctx: Context, argument: str) -> T_co: async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
"""|coro| """|coro|
The method to override to do conversion logic. The method to override to do conversion logic.
@ -162,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
2. Lookup by member, role, or channel mention. 2. Lookup by member, role, or channel mention.
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None: if match is None:
@ -195,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled. optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
""" """
async def query_member_named(self, guild, argument): async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#': if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#') username, _, discriminator = argument.rpartition('#')
@ -226,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
return None return None
return members[0] return members[0]
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
bot = ctx.bot bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild guild = ctx.guild
@ -280,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
and it's not available in cache. and it's not available in cache.
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User: async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None result = None
state = ctx._state state = ctx._state
@ -345,7 +344,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
if not match: if not match:
raise MessageNotFound(argument) raise MessageNotFound(argument)
data = match.groupdict() data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id') channel_id = discord.utils._get_as_snowflake(data, 'channel_id') or ctx.channel.id
message_id = int(data['message_id']) message_id = int(data['message_id'])
guild_id = data.get('guild_id') guild_id = data.get('guild_id')
if guild_id is None: if guild_id is None:
@ -358,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _resolve_channel( def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int] ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[Union[Channel, Thread]]: ) -> Optional[Union[Channel, Thread]]:
if channel_id is None: if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel # we were passed just a message id so we can assume the channel is the current context channel
@ -372,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return ctx.bot.get_channel(channel_id) return ctx.bot.get_channel(channel_id)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id) channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable): if not channel or not isinstance(channel, discord.abc.Messageable):
@ -395,14 +394,14 @@ class MessageConverter(IDConverter[discord.Message]):
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id) message = ctx.bot._connection._get_message(message_id)
if message: if message:
return message return message
channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id) channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable): if not channel or not isinstance(channel, discord.abc.Messageable):
raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here raise ChannelNotFound(channel_id)
try: try:
return await channel.fetch_message(message_id) return await channel.fetch_message(message_id)
except discord.NotFound: except discord.NotFound:
@ -426,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod @staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -447,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def check(c): def check(c):
return isinstance(c, type) and c.name == argument return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels()) result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
else: else:
channel_id = int(match.group(1)) channel_id = int(match.group(1))
if guild: if guild:
@ -462,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result return result
@staticmethod @staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -501,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
@ -521,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
@ -540,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
3. Lookup by name 3. Lookup by name
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
@ -560,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
@ -579,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel: async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
@ -597,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
.. versionadded: 2.0 .. versionadded: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
@ -629,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)') RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument): def parse_hex_number(self, argument: str) -> discord.Colour:
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try: try:
value = int(arg, base=16) value = int(arg, base=16)
@ -640,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
else: else:
return discord.Color(value=value) return discord.Color(value=value)
def parse_rgb_number(self, argument, number): def parse_rgb_number(self, argument: str, number: str) -> int:
if number[-1] == '%': if number[-1] == '%':
value = int(number[:-1]) value = int(number[:-1])
if not (0 <= value <= 100): if not (0 <= value <= 100):
@ -652,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(argument) raise BadColourArgument(argument)
return value return value
def parse_rgb(self, argument, *, regex=RGB_REGEX): def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
match = regex.match(argument) match = regex.match(argument)
if match is None: if match is None:
raise BadColourArgument(argument) raise BadColourArgument(argument)
@ -662,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
blue = self.parse_rgb_number(argument, match.group('b')) blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue) return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
if argument[0] == '#': if argument[0] == '#':
return self.parse_hex_number(argument[1:]) return self.parse_hex_number(argument[1:])
@ -703,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
guild = ctx.guild guild = ctx.guild
if not guild: if not guild:
raise NoPrivateMessage() raise NoPrivateMessage()
@ -722,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
class GameConverter(Converter[discord.Game]): class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`.""" """Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
return discord.Game(name=argument) return discord.Game(name=argument)
@ -735,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument` Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
try: try:
invite = await ctx.bot.fetch_invite(argument) invite = await ctx.bot.fetch_invite(argument)
return invite return invite
@ -754,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
match = self._get_id_match(argument) match = self._get_id_match(argument)
result = None result = None
@ -786,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument` Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji: async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None result = None
bot = ctx.bot bot = ctx.bot
@ -820,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument` Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji: async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match: if match:
@ -844,12 +843,12 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
The lookup strategy is as follows (in order): The lookup strategy is as follows (in order):
1. Lookup by ID. 1. Lookup by ID.
3. Lookup by name 2. Lookup by name.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker: async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument) match = self._get_id_match(argument)
result = None result = None
bot = ctx.bot bot = ctx.bot
@ -874,6 +873,65 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
return result return result
class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
"""Converts to a :class:`~discord.ScheduledEvent`.
Lookups are done for the local guild if available. Otherwise, for a DM context,
lookup is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by url.
3. Lookup by name.
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
guild = ctx.guild
match = self._get_id_match(argument)
result = None
if match:
# ID match
event_id = int(match.group(1))
if guild:
result = guild.get_scheduled_event(event_id)
else:
for guild in ctx.bot.guilds:
result = guild.get_scheduled_event(event_id)
if result:
break
else:
pattern = (
r'https?://(?:(ptb|canary|www)\.)?discord\.com/events/'
r'(?P<guild_id>[0-9]{15,20})/'
r'(?P<event_id>[0-9]{15,20})$'
)
match = re.match(pattern, argument, flags=re.I)
if match:
# URL match
guild = ctx.bot.get_guild(int(match.group('guild_id')))
if guild:
event_id = int(match.group('event_id'))
result = guild.get_scheduled_event(event_id)
else:
# lookup by name
if guild:
result = discord.utils.get(guild.scheduled_events, name=argument)
else:
for guild in ctx.bot.guilds:
result = discord.utils.get(guild.scheduled_events, name=argument)
if result:
break
if result is None:
raise ScheduledEventNotFound(argument)
return result
class clean_content(Converter[str]): class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of """Converts the argument to mention scrubbed version of
said content. said content.
@ -907,7 +965,7 @@ class clean_content(Converter[str]):
self.escape_markdown = escape_markdown self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown self.remove_markdown = remove_markdown
async def convert(self, ctx: Context[_Bot], argument: str) -> str: async def convert(self, ctx: Context[BotT], argument: str) -> str:
msg = ctx.message msg = ctx.message
if ctx.guild: if ctx.guild:
@ -924,7 +982,7 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str: def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user' return f'@{m.display_name}' if m else '@deleted-user'
def resolve_role(id: int) -> str: def resolve_role(id: int) -> str:
return '@deleted-role' return '@deleted-role'
@ -932,7 +990,7 @@ class clean_content(Converter[str]):
if self.fix_channel_mentions and ctx.guild: if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str: def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id) # type: ignore c = ctx.guild._resolve_channel(id) # type: ignore
return f'#{c.name}' if c else '#deleted-channel' return f'#{c.name}' if c else '#deleted-channel'
else: else:
@ -987,10 +1045,10 @@ class Greedy(List[T]):
__slots__ = ('converter',) __slots__ = ('converter',)
def __init__(self, *, converter: T): def __init__(self, *, converter: T) -> None:
self.converter = converter self.converter: T = converter
def __repr__(self): def __repr__(self) -> str:
converter = getattr(self.converter, '__name__', repr(self.converter)) converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]' return f'Greedy[{converter}]'
@ -1039,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
_GenericAlias = type(List[T]) _GenericAlias = type(List[T])
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
CONVERTER_MAPPING: Dict[Type[Any], Any] = { CONVERTER_MAPPING: Dict[type, Any] = {
discord.Object: ObjectConverter, discord.Object: ObjectConverter,
discord.Member: MemberConverter, discord.Member: MemberConverter,
discord.User: UserConverter, discord.User: UserConverter,
@ -1064,10 +1122,11 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
discord.Thread: ThreadConverter, discord.Thread: ThreadConverter,
discord.abc.GuildChannel: GuildChannelConverter, discord.abc.GuildChannel: GuildChannelConverter,
discord.GuildSticker: GuildStickerConverter, discord.GuildSticker: GuildStickerConverter,
discord.ScheduledEvent: ScheduledEventConverter,
} }
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
if converter is bool: if converter is bool:
return _convert_to_bool(argument) return _convert_to_bool(argument)
@ -1105,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
"""|coro| """|coro|
Runs converters for a given converter, argument, and parameter. Runs converters for a given converter, argument, and parameter.

4
discord/ext/commands/cooldowns.py

@ -220,7 +220,7 @@ class CooldownMapping:
return self._type return self._type
@classmethod @classmethod
def from_cooldown(cls, rate, per, type) -> Self: def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
return cls(Cooldown(rate, per), type) return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any: def _bucket_key(self, msg: Message) -> Any:
@ -297,7 +297,7 @@ class _Semaphore:
def __init__(self, number: int) -> None: def __init__(self, number: int) -> None:
self.value: int = number self.value: int = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._waiters: Deque[asyncio.Future] = deque() self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str: def __repr__(self) -> str:

195
discord/ext/commands/core.py

@ -61,11 +61,15 @@ if TYPE_CHECKING:
from discord.message import Message from discord.message import Message
from ._types import ( from ._types import (
BotT,
ContextT,
Coro, Coro,
CoroFunc, CoroFunc,
Check, Check,
Hook, Hook,
Error, Error,
ErrorT,
HookT,
) )
@ -101,10 +105,8 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar('T')
CogT = TypeVar('CogT', bound='Optional[Cog]') CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command') CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check') # CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group') GroupT = TypeVar('GroupT', bound='Group')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
if TYPE_CHECKING: if TYPE_CHECKING:
P = ParamSpec('P') P = ParamSpec('P')
@ -112,7 +114,7 @@ else:
P = TypeVar('P') P = TypeVar('P')
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: def unwrap_function(function: Callable[..., Any], /) -> Callable[..., Any]:
partial = functools.partial partial = functools.partial
while True: while True:
if hasattr(function, '__wrapped__'): if hasattr(function, '__wrapped__'):
@ -126,6 +128,7 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
def get_signature_parameters( def get_signature_parameters(
function: Callable[..., Any], function: Callable[..., Any],
globalns: Dict[str, Any], globalns: Dict[str, Any],
/,
*, *,
skip_parameters: Optional[int] = None, skip_parameters: Optional[int] = None,
) -> Dict[str, inspect.Parameter]: ) -> Dict[str, inspect.Parameter]:
@ -159,9 +162,9 @@ def get_signature_parameters(
return params return params
def wrap_callback(coro): def wrap_callback(coro: Callable[P, Coro[T]], /) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro) @functools.wraps(coro)
async def wrapped(*args, **kwargs): async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try: try:
ret = await coro(*args, **kwargs) ret = await coro(*args, **kwargs)
except CommandError: except CommandError:
@ -175,9 +178,11 @@ def wrap_callback(coro):
return wrapped return wrapped
def hooked_wrapped_callback(command, ctx, coro): def hooked_wrapped_callback(
command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]], /
) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro) @functools.wraps(coro)
async def wrapped(*args, **kwargs): async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try: try:
ret = await coro(*args, **kwargs) ret = await coro(*args, **kwargs)
except CommandError: except CommandError:
@ -191,7 +196,7 @@ def hooked_wrapped_callback(command, ctx, coro):
raise CommandInvokeError(exc) from exc raise CommandInvokeError(exc) from exc
finally: finally:
if command._max_concurrency is not None: if command._max_concurrency is not None:
await command._max_concurrency.release(ctx) await command._max_concurrency.release(ctx.message)
await command.call_after_hooks(ctx) await command.call_after_hooks(ctx)
return ret return ret
@ -318,6 +323,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]],
], ],
/,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
@ -359,7 +365,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError: except AttributeError:
checks = kwargs.get('checks', []) checks = kwargs.get('checks', [])
self.checks: List[Check] = checks self.checks: List[Check[ContextT]] = checks
try: try:
cooldown = func.__commands_cooldown__ cooldown = func.__commands_cooldown__
@ -387,8 +393,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.cog: CogT = None self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance # bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent') parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
self._before_invoke: Optional[Hook] = None self._before_invoke: Optional[Hook] = None
try: try:
@ -422,16 +428,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
) -> None: ) -> None:
self._callback = function self._callback = function
unwrap = unwrap_function(function) unwrap = unwrap_function(function)
self.module = unwrap.__module__ self.module: str = unwrap.__module__
try: try:
globalns = unwrap.__globals__ globalns = unwrap.__globals__
except AttributeError: except AttributeError:
globalns = {} globalns = {}
self.params = get_signature_parameters(function, globalns) self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check, /) -> None: def add_check(self, func: Check[ContextT], /) -> None:
"""Adds a check to the command. """Adds a check to the command.
This is the non-decorator interface to :func:`.check`. This is the non-decorator interface to :func:`.check`.
@ -450,7 +456,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func) self.checks.append(func)
def remove_check(self, func: Check, /) -> None: def remove_check(self, func: Check[ContextT], /) -> None:
"""Removes a check from the command. """Removes a check from the command.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@ -476,7 +482,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def update(self, **kwargs: Any) -> None: def update(self, **kwargs: Any) -> None:
"""Updates :class:`Command` instance with updated attribute. """Updates :class:`Command` instance with updated attribute.
This works similarly to the :func:`.command` decorator in terms This works similarly to the :func:`~discord.ext.commands.command` decorator in terms
of parameters in that they are passed to the :class:`Command` or of parameters in that they are passed to the :class:`Command` or
subclass constructors, sans the name and callback. subclass constructors, sans the name and callback.
""" """
@ -484,7 +490,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs)) self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
self.cog = cog self.cog = cog
async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T: async def __call__(self, context: Context[BotT], /, *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro| """|coro|
Calls the internal callback that the command holds. Calls the internal callback that the command holds.
@ -496,6 +502,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
the proper arguments and types to this function. the proper arguments and types to this function.
.. versionadded:: 1.3 .. versionadded:: 1.3
.. versionchanged:: 2.0
``context`` parameter is now positional-only.
""" """
if self.cog is not None: if self.cog is not None:
return await self.callback(self.cog, context, *args, **kwargs) # type: ignore return await self.callback(self.cog, context, *args, **kwargs) # type: ignore
@ -539,7 +549,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else: else:
return self.copy() return self.copy()
async def dispatch_error(self, ctx: Context, error: Exception) -> None: async def dispatch_error(self, ctx: Context[BotT], error: CommandError, /) -> None:
ctx.command_failed = True ctx.command_failed = True
cog = self.cog cog = self.cog
try: try:
@ -549,7 +559,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else: else:
injected = wrap_callback(coro) injected = wrap_callback(coro)
if cog is not None: if cog is not None:
await injected(cog, ctx, error) await injected(cog, ctx, error) # type: ignore
else: else:
await injected(ctx, error) await injected(ctx, error)
@ -562,7 +572,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally: finally:
ctx.bot.dispatch('command_error', ctx, error) ctx.bot.dispatch('command_error', ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: async def transform(self, ctx: Context[BotT], param: inspect.Parameter, /) -> Any:
required = param.default is param.empty required = param.default is param.empty
converter = get_converter(param) converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
@ -610,7 +620,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument # type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any: async def _transform_greedy_pos(
self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view view = ctx.view
result = [] result = []
while not view.eof: while not view.eof:
@ -631,7 +643,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return param.default return param.default
return result return result
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any:
view = ctx.view view = ctx.view
previous = view.index previous = view.index
try: try:
@ -669,7 +681,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(reversed(entries)) return ' '.join(reversed(entries))
@property @property
def parents(self) -> List[Group]: def parents(self) -> List[Group[Any, ..., Any]]:
"""List[:class:`Group`]: Retrieves the parents of this command. """List[:class:`Group`]: Retrieves the parents of this command.
If the command has no parents then it returns an empty :class:`list`. If the command has no parents then it returns an empty :class:`list`.
@ -687,7 +699,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return entries return entries
@property @property
def root_parent(self) -> Optional[Group]: def root_parent(self) -> Optional[Group[Any, ..., Any]]:
"""Optional[:class:`Group`]: Retrieves the root parent of this command. """Optional[:class:`Group`]: Retrieves the root parent of this command.
If the command has no parents then it returns ``None``. If the command has no parents then it returns ``None``.
@ -716,7 +728,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def __str__(self) -> str: def __str__(self) -> str:
return self.qualified_name return self.qualified_name
async def _parse_arguments(self, ctx: Context) -> None: async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.args = [ctx] if self.cog is None else [self.cog, ctx] ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
ctx.kwargs = {} ctx.kwargs = {}
args = ctx.args args = ctx.args
@ -752,7 +764,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not self.ignore_extra and not view.eof: if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None: async def call_before_hooks(self, ctx: Context[BotT], /) -> None:
# now that we're done preparing we can call the pre-command hooks # now that we're done preparing we can call the pre-command hooks
# first, call the command local hook: # first, call the command local hook:
cog = self.cog cog = self.cog
@ -777,7 +789,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None: if hook is not None:
await hook(ctx) await hook(ctx)
async def call_after_hooks(self, ctx: Context) -> None: async def call_after_hooks(self, ctx: Context[BotT], /) -> None:
cog = self.cog cog = self.cog
if self._after_invoke is not None: if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog) instance = getattr(self._after_invoke, '__self__', cog)
@ -796,7 +808,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None: if hook is not None:
await hook(ctx) await hook(ctx)
def _prepare_cooldowns(self, ctx: Context) -> None: def _prepare_cooldowns(self, ctx: Context[BotT]) -> None:
if self._buckets.valid: if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
@ -806,7 +818,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if retry_after: if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
async def prepare(self, ctx: Context) -> None: async def prepare(self, ctx: Context[BotT], /) -> None:
ctx.command = self ctx.command = self
if not await self.can_run(ctx): if not await self.can_run(ctx):
@ -830,9 +842,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
await self._max_concurrency.release(ctx) # type: ignore await self._max_concurrency.release(ctx) # type: ignore
raise raise
def is_on_cooldown(self, ctx: Context) -> bool: def is_on_cooldown(self, ctx: Context[BotT], /) -> bool:
"""Checks whether the command is currently on cooldown. """Checks whether the command is currently on cooldown.
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters Parameters
----------- -----------
ctx: :class:`.Context` ctx: :class:`.Context`
@ -851,9 +867,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0 return bucket.get_tokens(current) == 0
def reset_cooldown(self, ctx: Context) -> None: def reset_cooldown(self, ctx: Context[BotT], /) -> None:
"""Resets the cooldown on this command. """Resets the cooldown on this command.
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters Parameters
----------- -----------
ctx: :class:`.Context` ctx: :class:`.Context`
@ -863,11 +883,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
bucket = self._buckets.get_bucket(ctx.message) bucket = self._buckets.get_bucket(ctx.message)
bucket.reset() bucket.reset()
def get_cooldown_retry_after(self, ctx: Context) -> float: def get_cooldown_retry_after(self, ctx: Context[BotT], /) -> float:
"""Retrieves the amount of seconds before this command can be tried again. """Retrieves the amount of seconds before this command can be tried again.
.. versionadded:: 1.4 .. versionadded:: 1.4
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters Parameters
----------- -----------
ctx: :class:`.Context` ctx: :class:`.Context`
@ -887,7 +911,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return 0.0 return 0.0
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx: Context[BotT], /) -> None:
await self.prepare(ctx) await self.prepare(ctx)
# terminate the invoked_subcommand chain. # terminate the invoked_subcommand chain.
@ -896,9 +920,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.invoked_subcommand = None ctx.invoked_subcommand = None
ctx.subcommand_passed = None ctx.subcommand_passed = None
injected = hooked_wrapped_callback(self, ctx, self.callback) injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs) await injected(*ctx.args, **ctx.kwargs) # type: ignore
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: async def reinvoke(self, ctx: Context[BotT], /, *, call_hooks: bool = False) -> None:
ctx.command = self ctx.command = self
await self._parse_arguments(ctx) await self._parse_arguments(ctx)
@ -915,13 +939,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if call_hooks: if call_hooks:
await self.call_after_hooks(ctx) await self.call_after_hooks(ctx)
def error(self, coro: FuncT) -> FuncT: def error(self, coro: ErrorT, /) -> ErrorT:
"""A decorator that registers a coroutine as a local error handler. """A decorator that registers a coroutine as a local error handler.
A local error handler is an :func:`.on_command_error` event limited to A local error handler is an :func:`.on_command_error` event limited to
a single command. However, the :func:`.on_command_error` is still a single command. However, the :func:`.on_command_error` is still
invoked afterwards as the catch-all. invoked afterwards as the catch-all.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters Parameters
----------- -----------
coro: :ref:`coroutine <coroutine>` coro: :ref:`coroutine <coroutine>`
@ -936,7 +964,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not asyncio.iscoroutinefunction(coro): if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.') raise TypeError('The error handler must be a coroutine.')
self.on_error: Error = coro self.on_error: Error[Any] = coro
return coro return coro
def has_error_handler(self) -> bool: def has_error_handler(self) -> bool:
@ -946,7 +974,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
""" """
return hasattr(self, 'on_error') return hasattr(self, 'on_error')
def before_invoke(self, coro: FuncT) -> FuncT: def before_invoke(self, coro: HookT, /) -> HookT:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is A pre-invoke hook is called directly before the command is
@ -957,6 +985,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
See :meth:`.Bot.before_invoke` for more info. See :meth:`.Bot.before_invoke` for more info.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters Parameters
----------- -----------
coro: :ref:`coroutine <coroutine>` coro: :ref:`coroutine <coroutine>`
@ -973,7 +1005,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self._before_invoke = coro self._before_invoke = coro
return coro return coro
def after_invoke(self, coro: FuncT) -> FuncT: def after_invoke(self, coro: HookT, /) -> HookT:
"""A decorator that registers a coroutine as a post-invoke hook. """A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is A post-invoke hook is called directly after the command is
@ -984,6 +1016,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
See :meth:`.Bot.after_invoke` for more info. See :meth:`.Bot.after_invoke` for more info.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters Parameters
----------- -----------
coro: :ref:`coroutine <coroutine>` coro: :ref:`coroutine <coroutine>`
@ -1075,7 +1111,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(result) return ' '.join(result)
async def can_run(self, ctx: Context) -> bool: async def can_run(self, ctx: Context[BotT], /) -> bool:
"""|coro| """|coro|
Checks if the command can be executed by checking all the predicates Checks if the command can be executed by checking all the predicates
@ -1085,6 +1121,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionchanged:: 1.3 .. versionchanged:: 1.3
Checks whether the command is disabled or not Checks whether the command is disabled or not
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters Parameters
----------- -----------
ctx: :class:`.Context` ctx: :class:`.Context`
@ -1341,11 +1381,11 @@ class GroupMixin(Generic[CogT]):
def command( def command(
self, self,
name: str = MISSING, name: str = MISSING,
cls: Type[Command] = MISSING, cls: Type[Command[Any, ..., Any]] = MISSING,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""A shortcut decorator that invokes :func:`.command` and adds it to """A shortcut decorator that invokes :func:`~discord.ext.commands.command` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`. the internal command list via :meth:`~.GroupMixin.add_command`.
Returns Returns
@ -1401,7 +1441,7 @@ class GroupMixin(Generic[CogT]):
def group( def group(
self, self,
name: str = MISSING, name: str = MISSING,
cls: Type[Group] = MISSING, cls: Type[Group[Any, ..., Any]] = MISSING,
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
@ -1461,9 +1501,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
ret = super().copy() ret = super().copy()
for cmd in self.commands: for cmd in self.commands:
ret.add_command(cmd.copy()) ret.add_command(cmd.copy())
return ret # type: ignore return ret
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx: Context[BotT], /) -> None:
ctx.invoked_subcommand = None ctx.invoked_subcommand = None
ctx.subcommand_passed = None ctx.subcommand_passed = None
early_invoke = not self.invoke_without_command early_invoke = not self.invoke_without_command
@ -1481,7 +1521,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
if early_invoke: if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback) injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs) await injected(*ctx.args, **ctx.kwargs) # type: ignore
ctx.invoked_parents.append(ctx.invoked_with) # type: ignore ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
@ -1494,7 +1534,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous view.previous = previous
await super().invoke(ctx) await super().invoke(ctx)
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: async def reinvoke(self, ctx: Context[BotT], /, *, call_hooks: bool = False) -> None:
ctx.invoked_subcommand = None ctx.invoked_subcommand = None
early_invoke = not self.invoke_without_command early_invoke = not self.invoke_without_command
if early_invoke: if early_invoke:
@ -1592,7 +1632,7 @@ def command(
def command( def command(
name: str = MISSING, name: str = MISSING,
cls: Type[Command] = MISSING, cls: Type[Command[Any, ..., Any]] = MISSING,
**attrs: Any, **attrs: Any,
) -> Any: ) -> Any:
"""A decorator that transforms a function into a :class:`.Command` """A decorator that transforms a function into a :class:`.Command`
@ -1662,12 +1702,12 @@ def group(
def group( def group(
name: str = MISSING, name: str = MISSING,
cls: Type[Group] = MISSING, cls: Type[Group[Any, ..., Any]] = MISSING,
**attrs: Any, **attrs: Any,
) -> Any: ) -> Any:
"""A decorator that transforms a function into a :class:`.Group`. """A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls`` This is similar to the :func:`~discord.ext.commands.command` decorator but the ``cls``
parameter is set to :class:`Group` by default. parameter is set to :class:`Group` by default.
.. versionchanged:: 1.1 .. versionchanged:: 1.1
@ -1679,7 +1719,7 @@ def group(
return command(name=name, cls=cls, **attrs) return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]: def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`. subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1744,6 +1784,10 @@ def check(predicate: Check) -> Callable[[T], T]:
async def only_me(ctx): async def only_me(ctx):
await ctx.send('Only you!') await ctx.send('Only you!')
.. versionchanged:: 2.0
``predicate`` parameter is now positional-only.
Parameters Parameters
----------- -----------
predicate: Callable[[:class:`Context`], :class:`bool`] predicate: Callable[[:class:`Context`], :class:`bool`]
@ -1774,7 +1818,7 @@ def check(predicate: Check) -> Callable[[T], T]:
return decorator # type: ignore return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]: def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR. will pass, i.e. using logical OR.
@ -1827,7 +1871,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
else: else:
unwrapped.append(pred) unwrapped.append(pred)
async def predicate(ctx: Context) -> bool: async def predicate(ctx: Context[BotT]) -> bool:
errors = [] errors = []
for func in unwrapped: for func in unwrapped:
try: try:
@ -1843,7 +1887,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def has_role(item: Union[int, str]) -> Callable[[T], T]: def has_role(item: Union[int, str], /) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the """A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified. command has the role specified via the name or ID specified.
@ -1864,13 +1908,17 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
Raise :exc:`.MissingRole` or :exc:`.NoPrivateMessage` Raise :exc:`.MissingRole` or :exc:`.NoPrivateMessage`
instead of generic :exc:`.CheckFailure` instead of generic :exc:`.CheckFailure`
.. versionchanged:: 2.0
``item`` parameter is now positional-only.
Parameters Parameters
----------- -----------
item: Union[:class:`int`, :class:`str`] item: Union[:class:`int`, :class:`str`]
The name or ID of the role to check. The name or ID of the role to check.
""" """
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None: if ctx.guild is None:
raise NoPrivateMessage() raise NoPrivateMessage()
@ -1923,7 +1971,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
raise NoPrivateMessage() raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member # ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore getter = functools.partial(discord.utils.get, ctx.author.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True return True
raise MissingAnyRole(list(items)) raise MissingAnyRole(list(items))
@ -1931,7 +1979,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
return check(predicate) return check(predicate)
def bot_has_role(item: int) -> Callable[[T], T]: def bot_has_role(item: int, /) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the """Similar to :func:`.has_role` except checks if the bot itself has the
role. role.
@ -1943,6 +1991,10 @@ def bot_has_role(item: int) -> Callable[[T], T]:
Raise :exc:`.BotMissingRole` or :exc:`.NoPrivateMessage` Raise :exc:`.BotMissingRole` or :exc:`.NoPrivateMessage`
instead of generic :exc:`.CheckFailure` instead of generic :exc:`.CheckFailure`
.. versionchanged:: 2.0
``item`` parameter is now positional-only.
""" """
def predicate(ctx): def predicate(ctx):
@ -2022,7 +2074,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
ch = ctx.channel ch = ctx.channel
permissions = ch.permissions_for(ctx.author) # type: ignore permissions = ch.permissions_for(ctx.author) # type: ignore
@ -2048,7 +2100,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
guild = ctx.guild guild = ctx.guild
me = guild.me if guild is not None else ctx.bot.user me = guild.me if guild is not None else ctx.bot.user
permissions = ctx.channel.permissions_for(me) # type: ignore permissions = ctx.channel.permissions_for(me) # type: ignore
@ -2077,7 +2129,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild: if not ctx.guild:
raise NoPrivateMessage raise NoPrivateMessage
@ -2103,7 +2155,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid: if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}") raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild: if not ctx.guild:
raise NoPrivateMessage raise NoPrivateMessage
@ -2129,7 +2181,7 @@ def dm_only() -> Callable[[T], T]:
.. versionadded:: 1.1 .. versionadded:: 1.1
""" """
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is not None: if ctx.guild is not None:
raise PrivateMessageOnly() raise PrivateMessageOnly()
return True return True
@ -2146,7 +2198,7 @@ def guild_only() -> Callable[[T], T]:
that is inherited from :exc:`.CheckFailure`. that is inherited from :exc:`.CheckFailure`.
""" """
def predicate(ctx: Context) -> bool: def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None: if ctx.guild is None:
raise NoPrivateMessage() raise NoPrivateMessage()
return True return True
@ -2164,7 +2216,7 @@ def is_owner() -> Callable[[T], T]:
from :exc:`.CheckFailure`. from :exc:`.CheckFailure`.
""" """
async def predicate(ctx: Context) -> bool: async def predicate(ctx: Context[BotT]) -> bool:
if not await ctx.bot.is_owner(ctx.author): if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.') raise NotOwner('You do not own this bot.')
return True return True
@ -2184,7 +2236,7 @@ def is_nsfw() -> Callable[[T], T]:
DM channels will also now pass this check. DM channels will also now pass this check.
""" """
def pred(ctx: Context) -> bool: def pred(ctx: Context[BotT]) -> bool:
ch = ctx.channel ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True return True
@ -2314,7 +2366,7 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
return decorator # type: ignore return decorator # type: ignore
def before_invoke(coro) -> Callable[[T], T]: def before_invoke(coro: Hook[ContextT], /) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
This allows you to refer to one before invoke hook for several commands that This allows you to refer to one before invoke hook for several commands that
@ -2322,6 +2374,10 @@ def before_invoke(coro) -> Callable[[T], T]:
.. versionadded:: 1.4 .. versionadded:: 1.4
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Example Example
--------- ---------
@ -2350,7 +2406,6 @@ def before_invoke(coro) -> Callable[[T], T]:
async def why(self, ctx): # Output: <Nothing> async def why(self, ctx): # Output: <Nothing>
await ctx.send('because someone made me') await ctx.send('because someone made me')
bot.add_cog(What())
""" """
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
@ -2363,13 +2418,17 @@ def before_invoke(coro) -> Callable[[T], T]:
return decorator # type: ignore return decorator # type: ignore
def after_invoke(coro) -> Callable[[T], T]: def after_invoke(coro: Hook[ContextT], /) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a post-invoke hook. """A decorator that registers a coroutine as a post-invoke hook.
This allows you to refer to one after invoke hook for several commands that This allows you to refer to one after invoke hook for several commands that
do not have to be within the same cog. do not have to be within the same cog.
.. versionadded:: 1.4 .. versionadded:: 1.4
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
""" """
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:

33
discord/ext/commands/errors.py

@ -39,6 +39,8 @@ if TYPE_CHECKING:
from discord.threads import Thread from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList from discord.types.snowflake import Snowflake, SnowflakeList
from ._types import BotT
__all__ = ( __all__ = (
'CommandError', 'CommandError',
@ -70,6 +72,7 @@ __all__ = (
'BadInviteArgument', 'BadInviteArgument',
'EmojiNotFound', 'EmojiNotFound',
'GuildStickerNotFound', 'GuildStickerNotFound',
'ScheduledEventNotFound',
'PartialEmojiConversionFailure', 'PartialEmojiConversionFailure',
'BadBoolArgument', 'BadBoolArgument',
'MissingRole', 'MissingRole',
@ -134,8 +137,8 @@ class ConversionError(CommandError):
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, converter: Converter, original: Exception) -> None: def __init__(self, converter: Converter[Any], original: Exception) -> None:
self.converter: Converter = converter self.converter: Converter[Any] = converter
self.original: Exception = original self.original: Exception = original
@ -223,9 +226,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed. A list of check predicates that failed.
""" """
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None:
self.checks: List[CheckFailure] = checks self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors self.errors: List[Callable[[Context[BotT]], bool]] = errors
super().__init__('You do not have permission to run this command.') super().__init__('You do not have permission to run this command.')
@ -515,6 +518,24 @@ class GuildStickerNotFound(BadArgument):
super().__init__(f'Sticker "{argument}" not found.') super().__init__(f'Sticker "{argument}" not found.')
class ScheduledEventNotFound(BadArgument):
"""Exception raised when the bot can not find the scheduled event.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.0
Attributes
-----------
argument: :class:`str`
The event supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'ScheduledEvent "{argument}" not found.')
class BadBoolArgument(BadArgument): class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable. """Exception raised when a boolean argument was not convertable.
@ -788,9 +809,9 @@ class BadUnionArgument(UserInputError):
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters self.converters: Tuple[type, ...] = converters
self.errors: List[CommandError] = errors self.errors: List[CommandError] = errors
def _get_name(x): def _get_name(x):

24
discord/ext/commands/flags.py

@ -49,8 +49,6 @@ from typing import (
Tuple, Tuple,
List, List,
Any, Any,
Type,
TypeVar,
Union, Union,
) )
@ -70,6 +68,8 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
from ._types import BotT
@dataclass @dataclass
class Flag: class Flag:
@ -148,7 +148,7 @@ def flag(
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]): def validate_flag_name(name: str, forbidden: Set[str]) -> None:
if not name: if not name:
raise ValueError('flag names should not be empty') raise ValueError('flag names should not be empty')
@ -348,7 +348,7 @@ class FlagsMeta(type):
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
view = StringView(argument) view = StringView(argument)
results = [] results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results) return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
view = StringView(argument) view = StringView(argument)
results = [] results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters
return tuple(results) return tuple(results)
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any: async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation annotation = annotation or flag.annotation
try: try:
@ -480,12 +480,13 @@ class FlagConverter(metaclass=FlagsMeta):
yield (flag.name, getattr(self, flag.attribute)) yield (flag.name, getattr(self, flag.attribute))
@classmethod @classmethod
async def _construct_default(cls, ctx: Context) -> Self: async def _construct_default(cls, ctx: Context[BotT]) -> Self:
self = cls.__new__(cls) self = cls.__new__(cls)
flags = cls.__commands_flags__ flags = cls.__commands_flags__
for flag in flags.values(): for flag in flags.values():
if callable(flag.default): if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx) # Type checker does not understand that flag.default is a Callable
default = await maybe_coroutine(flag.default, ctx) # type: ignore
setattr(self, flag.attribute, default) setattr(self, flag.attribute, default)
else: else:
setattr(self, flag.attribute, flag.default) setattr(self, flag.attribute, flag.default)
@ -546,7 +547,7 @@ class FlagConverter(metaclass=FlagsMeta):
return result return result
@classmethod @classmethod
async def convert(cls, ctx: Context, argument: str) -> Self: async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
"""|coro| """|coro|
The method that actually converters an argument to the flag mapping. The method that actually converters an argument to the flag mapping.
@ -584,7 +585,8 @@ class FlagConverter(metaclass=FlagsMeta):
raise MissingRequiredFlag(flag) raise MissingRequiredFlag(flag)
else: else:
if callable(flag.default): if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx) # Type checker does not understand flag.default is a Callable
default = await maybe_coroutine(flag.default, ctx) # type: ignore
setattr(self, flag.attribute, default) setattr(self, flag.attribute, default)
else: else:
setattr(self, flag.attribute, flag.default) setattr(self, flag.attribute, flag.default)
@ -610,7 +612,7 @@ class FlagConverter(metaclass=FlagsMeta):
values = [await convert_flag(ctx, value, flag) for value in values] values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict: if flag.cast_to_dict:
values = dict(values) # type: ignore values = dict(values)
setattr(self, flag.attribute, values) setattr(self, flag.attribute, values)

585
discord/ext/commands/help.py

File diff suppressed because it is too large

37
discord/ext/commands/view.py

@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Optional
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes # map from opening quotes to closing quotes
@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView: class StringView:
def __init__(self, buffer): def __init__(self, buffer: str) -> None:
self.index = 0 self.index: int = 0
self.buffer = buffer self.buffer: str = buffer
self.end = len(buffer) self.end: int = len(buffer)
self.previous = 0 self.previous = 0
@property @property
def current(self): def current(self) -> Optional[str]:
return None if self.eof else self.buffer[self.index] return None if self.eof else self.buffer[self.index]
@property @property
def eof(self): def eof(self) -> bool:
return self.index >= self.end return self.index >= self.end
def undo(self): def undo(self) -> None:
self.index = self.previous self.index = self.previous
def skip_ws(self): def skip_ws(self) -> bool:
pos = 0 pos = 0
while not self.eof: while not self.eof:
try: try:
@ -79,7 +84,7 @@ class StringView:
self.index += pos self.index += pos
return self.previous != self.index return self.previous != self.index
def skip_string(self, string): def skip_string(self, string: str) -> bool:
strlen = len(string) strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string: if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index self.previous = self.index
@ -87,19 +92,19 @@ class StringView:
return True return True
return False return False
def read_rest(self): def read_rest(self) -> str:
result = self.buffer[self.index :] result = self.buffer[self.index :]
self.previous = self.index self.previous = self.index
self.index = self.end self.index = self.end
return result return result
def read(self, n): def read(self, n: int) -> str:
result = self.buffer[self.index : self.index + n] result = self.buffer[self.index : self.index + n]
self.previous = self.index self.previous = self.index
self.index += n self.index += n
return result return result
def get(self): def get(self) -> Optional[str]:
try: try:
result = self.buffer[self.index + 1] result = self.buffer[self.index + 1]
except IndexError: except IndexError:
@ -109,7 +114,7 @@ class StringView:
self.index += 1 self.index += 1
return result return result
def get_word(self): def get_word(self) -> str:
pos = 0 pos = 0
while not self.eof: while not self.eof:
try: try:
@ -119,12 +124,12 @@ class StringView:
pos += 1 pos += 1
except IndexError: except IndexError:
break break
self.previous = self.index self.previous: int = self.index
result = self.buffer[self.index : self.index + pos] result = self.buffer[self.index : self.index + pos]
self.index += pos self.index += pos
return result return result
def get_quoted_word(self): def get_quoted_word(self) -> Optional[str]:
current = self.current current = self.current
if current is None: if current is None:
return None return None
@ -187,5 +192,5 @@ class StringView:
result.append(current) result.append(current)
def __repr__(self): def __repr__(self) -> str:
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>' return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

182
discord/ext/tasks/__init__.py

@ -26,6 +26,7 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
import logging
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
@ -48,6 +49,8 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
_log = logging.getLogger(__name__)
# fmt: off # fmt: off
__all__ = ( __all__ = (
'loop', 'loop',
@ -61,19 +64,61 @@ FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
def is_ambiguous(dt: datetime.datetime) -> bool:
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
# Naive or fixed timezones are never ambiguous
return False
before = dt.replace(fold=0)
after = dt.replace(fold=1)
same_offset = before.utcoffset() == after.utcoffset()
same_dst = before.dst() == after.dst()
return not (same_offset and same_dst)
def is_imaginary(dt: datetime.datetime) -> bool:
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
# Naive or fixed timezones are never imaginary
return False
tz = dt.tzinfo
dt = dt.replace(tzinfo=None)
roundtrip = dt.replace(tzinfo=tz).astimezone(datetime.timezone.utc).astimezone(tz).replace(tzinfo=None)
return dt != roundtrip
def resolve_datetime(dt: datetime.datetime) -> datetime.datetime:
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
# Naive or fixed requires no resolution
return dt
if is_imaginary(dt):
# Largest gap is probably 24 hours
tomorrow = dt + datetime.timedelta(days=1)
yesterday = dt - datetime.timedelta(days=1)
# utcoffset shouldn't return None since these are aware instances
# If it returns None then the timezone implementation was broken from the get go
return dt + (tomorrow.utcoffset() - yesterday.utcoffset()) # type: ignore
elif is_ambiguous(dt):
return dt.replace(fold=1)
else:
return dt
class SleepHandle: class SleepHandle:
__slots__ = ('future', 'loop', 'handle') __slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
self.future = future = loop.create_future() self.future: asyncio.Future[None] = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True) self.handle = loop.call_later(relative_delta, self.future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None: def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel() self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt) relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]: def wait(self) -> asyncio.Future[Any]:
return self.future return self.future
@ -101,15 +146,13 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]], time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int], count: Optional[int],
reconnect: bool, reconnect: bool,
loop: asyncio.AbstractEventLoop,
) -> None: ) -> None:
self.coro: LF = coro self.coro: LF = coro
self.reconnect: bool = reconnect self.reconnect: bool = reconnect
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count self.count: Optional[int] = count
self._current_loop = 0 self._current_loop = 0
self._handle: SleepHandle = MISSING self._handle: Optional[SleepHandle] = None
self._task: asyncio.Task[None] = MISSING self._task: Optional[asyncio.Task[None]] = None
self._injected = None self._injected = None
self._valid_exception = ( self._valid_exception = (
OSError, OSError,
@ -147,16 +190,20 @@ class Loop(Generic[LF]):
await coro(*args, **kwargs) await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime): def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
return self._handle.wait() return self._handle.wait()
def _is_relative_time(self) -> bool:
return self._time is MISSING
def _is_explicit_time(self) -> bool:
return self._time is not MISSING
async def _loop(self, *args: Any, **kwargs: Any) -> None: async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff() backoff = ExponentialBackoff()
await self._call_loop_function('before_loop') await self._call_loop_function('before_loop')
self._last_iteration_failed = False self._last_iteration_failed = False
if self._time is not MISSING: if self._is_explicit_time():
# the time index should be prepared every time the internal loop is started
self._prepare_time_index()
self._next_iteration = self._get_next_sleep_time() self._next_iteration = self._get_next_sleep_time()
else: else:
self._next_iteration = datetime.datetime.now(datetime.timezone.utc) self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
@ -166,11 +213,30 @@ class Loop(Generic[LF]):
return return
while True: while True:
# sleep before the body of the task for explicit time intervals # sleep before the body of the task for explicit time intervals
if self._time is not MISSING: if self._is_explicit_time():
await self._try_sleep_until(self._next_iteration) await self._try_sleep_until(self._next_iteration)
if not self._last_iteration_failed: if not self._last_iteration_failed:
self._last_iteration = self._next_iteration self._last_iteration = self._next_iteration
self._next_iteration = self._get_next_sleep_time() self._next_iteration = self._get_next_sleep_time()
# In order to account for clock drift, we need to ensure that
# the next iteration always follows the last iteration.
# Sometimes asyncio is cheeky and wakes up a few microseconds before our target
# time, causing it to repeat a run.
while self._is_explicit_time() and self._next_iteration <= self._last_iteration:
_log.warn(
(
'Clock drift detected for task %s. Woke up at %s but needed to sleep until %s. '
'Sleeping until %s again to correct clock'
),
self.coro.__qualname__,
discord.utils.utcnow(),
self._next_iteration,
self._next_iteration,
)
await self._try_sleep_until(self._next_iteration)
self._next_iteration = self._get_next_sleep_time()
try: try:
await self.coro(*args, **kwargs) await self.coro(*args, **kwargs)
self._last_iteration_failed = False self._last_iteration_failed = False
@ -184,7 +250,7 @@ class Loop(Generic[LF]):
return return
# sleep after the body of the task for relative time intervals # sleep after the body of the task for relative time intervals
if self._time is MISSING: if self._is_relative_time():
await self._try_sleep_until(self._next_iteration) await self._try_sleep_until(self._next_iteration)
self._current_loop += 1 self._current_loop += 1
@ -200,6 +266,7 @@ class Loop(Generic[LF]):
raise exc raise exc
finally: finally:
await self._call_loop_function('after_loop') await self._call_loop_function('after_loop')
if self._handle:
self._handle.cancel() self._handle.cancel()
self._is_being_cancelled = False self._is_being_cancelled = False
self._current_loop = 0 self._current_loop = 0
@ -218,7 +285,6 @@ class Loop(Generic[LF]):
time=self._time, time=self._time,
count=self.count, count=self.count,
reconnect=self.reconnect, reconnect=self.reconnect,
loop=self.loop,
) )
copy._injected = obj copy._injected = obj
copy._before_loop = self._before_loop copy._before_loop = self._before_loop
@ -325,16 +391,13 @@ class Loop(Generic[LF]):
The task that has been created. The task that has been created.
""" """
if self._task is not MISSING and not self._task.done(): if self._task and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.') raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None: if self._injected is not None:
args = (self._injected, *args) args = (self._injected, *args)
if self.loop is MISSING: self._task = asyncio.create_task(self._loop(*args, **kwargs))
self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
return self._task return self._task
def stop(self) -> None: def stop(self) -> None:
@ -358,7 +421,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2 .. versionadded:: 1.2
""" """
if self._task is not MISSING and not self._task.done(): if self._task and not self._task.done():
self._stop_next_iteration = True self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool: def _can_be_cancelled(self) -> bool:
@ -366,7 +429,7 @@ class Loop(Generic[LF]):
def cancel(self) -> None: def cancel(self) -> None:
"""Cancels the internal task, if it is running.""" """Cancels the internal task, if it is running."""
if self._can_be_cancelled(): if self._can_be_cancelled() and self._task:
self._task.cancel() self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None: def restart(self, *args: Any, **kwargs: Any) -> None:
@ -386,10 +449,11 @@ class Loop(Generic[LF]):
""" """
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None: def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
if self._task:
self._task.remove_done_callback(restart_when_over) self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs) self.start(*args, **kwargs)
if self._can_be_cancelled(): if self._can_be_cancelled() and self._task:
self._task.add_done_callback(restart_when_over) self._task.add_done_callback(restart_when_over)
self._task.cancel() self._task.cancel()
@ -468,7 +532,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
return not bool(self._task.done()) if self._task is not MISSING else False return not bool(self._task.done()) if self._task else False
async def _error(self, *args: Any) -> None: async def _error(self, *args: Any) -> None:
exception: Exception = args[-1] exception: Exception = args[-1]
@ -557,47 +621,50 @@ class Loop(Generic[LF]):
self._error = coro # type: ignore self._error = coro # type: ignore
return coro return coro
def _get_next_sleep_time(self) -> datetime.datetime: def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime:
if self._sleep is not MISSING: if self._sleep is not MISSING:
return self._last_iteration + datetime.timedelta(seconds=self._sleep) return self._last_iteration + datetime.timedelta(seconds=self._sleep)
if self._time_index >= len(self._time): if now is MISSING:
self._time_index = 0 now = datetime.datetime.now(datetime.timezone.utc)
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index]
if self._current_loop == 0: index = self._start_time_relative_to(now)
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = self._last_iteration if index is None:
if self._time_index == 0: time = self._time[0]
# we can assume that the earliest time should be scheduled for "tomorrow" tomorrow = now.astimezone(time.tzinfo) + datetime.timedelta(days=1)
next_date += datetime.timedelta(days=1) date = tomorrow.date()
else:
time = self._time[index]
date = now.astimezone(time.tzinfo).date()
self._time_index += 1 dt = datetime.datetime.combine(date, time, tzinfo=time.tzinfo)
return datetime.datetime.combine(next_date, next_time) return resolve_datetime(dt)
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None: def _start_time_relative_to(self, now: datetime.datetime) -> Optional[int]:
# now kwarg should be a datetime.datetime representing the time "now" # now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from # to calculate the next time index from
# pre-condition: self._time is set # pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) # Sole time comparisons are apparently broken, therefore, attach today's date
).timetz() # to it in order to make the comparisons make sense.
idx = -1 # For example, if given a list of times [0, 3, 18]
# If it's 04:00 today then we know we have to wait until 18:00 today
# If it's 19:00 today then we know we we have to wait until 00:00 tomorrow
# Note that timezones need to be taken into consideration for this to work.
# If the timezone is set to UTC+9 and the now timezone is UTC
# A conversion needs to be done.
# i.e. 03:00 UTC+9 -> 18:00 UTC the previous day
for idx, time in enumerate(self._time): for idx, time in enumerate(self._time):
if time >= time_now: # Convert the current time to the target timezone
self._time_index = idx # e.g. 18:00 UTC -> 03:00 UTC+9
break # Then compare the time instances to see if they're the same
start = now.astimezone(time.tzinfo)
if time >= start.timetz():
return idx
else: else:
self._time_index = idx + 1 return None
def _get_time_parameter( def _get_time_parameter(
self, self,
@ -687,12 +754,8 @@ class Loop(Generic[LF]):
self._sleep = self._seconds = self._minutes = self._hours = MISSING self._sleep = self._seconds = self._minutes = self._hours = MISSING
if self.is_running(): if self.is_running():
if self._time is not MISSING:
# prepare the next time index starting from after the last iteration
self._prepare_time_index(now=self._last_iteration)
self._next_iteration = self._get_next_sleep_time() self._next_iteration = self._get_next_sleep_time()
if self._handle is not MISSING and not self._handle.done(): if self._handle and not self._handle.done():
# the loop is sleeping, recalculate based on new interval # the loop is sleeping, recalculate based on new interval
self._handle.recalculate(self._next_iteration) self._handle.recalculate(self._next_iteration)
@ -705,7 +768,6 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING, time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None, count: Optional[int] = None,
reconnect: bool = True, reconnect: bool = True,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]: ) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@ -738,9 +800,6 @@ def loop(
Whether to handle errors and restart the task Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`. one used in :meth:`discord.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
Raises Raises
-------- --------
@ -760,7 +819,6 @@ def loop(
count=count, count=count,
time=time, time=time,
reconnect=reconnect, reconnect=reconnect,
loop=loop,
) )
return decorator return decorator

2
discord/file.py

@ -78,7 +78,7 @@ class File:
def __init__( def __init__(
self, self,
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase],
filename: Optional[str] = None, filename: Optional[str] = None,
*, *,
spoiler: bool = False, spoiler: bool = False,

21
discord/flags.py

@ -45,8 +45,8 @@ BF = TypeVar('BF', bound='BaseFlags')
class flag_value: class flag_value:
def __init__(self, func: Callable[[Any], int]): def __init__(self, func: Callable[[Any], int]):
self.flag = func(None) self.flag: int = func(None)
self.__doc__ = func.__doc__ self.__doc__: Optional[str] = func.__doc__
@overload @overload
def __get__(self, instance: None, owner: Type[BF]) -> Self: def __get__(self, instance: None, owner: Type[BF]) -> Self:
@ -64,7 +64,7 @@ class flag_value:
def __set__(self, instance: BaseFlags, value: bool) -> None: def __set__(self, instance: BaseFlags, value: bool) -> None:
instance._set_flag(self.flag, value) instance._set_flag(self.flag, value)
def __repr__(self): def __repr__(self) -> str:
return f'<flag_value flag={self.flag!r}>' return f'<flag_value flag={self.flag!r}>'
@ -72,8 +72,8 @@ class alias_flag_value(flag_value):
pass pass
def fill_with_flags(*, inverted: bool = False): def fill_with_flags(*, inverted: bool = False) -> Callable[[Type[BF]], Type[BF]]:
def decorator(cls: Type[BF]): def decorator(cls: Type[BF]) -> Type[BF]:
# fmt: off # fmt: off
cls.VALID_FLAGS = { cls.VALID_FLAGS = {
name: value.flag name: value.flag
@ -115,10 +115,10 @@ class BaseFlags:
self.value = value self.value = value
return self return self
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.value == other.value return isinstance(other, self.__class__) and self.value == other.value
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -533,6 +533,11 @@ class PrivateUserFlags(PublicUserFlags):
""":class:`bool`: Returns ``True`` if the user has a partner or a verification application.""" """:class:`bool`: Returns ``True`` if the user has a partner or a verification application."""
return UserFlags.partner_or_verification_application.value return UserFlags.partner_or_verification_application.value
@flag_value
def disable_premium(self):
""":class:`bool`: Returns ``True`` if the user bought premium but has it manually disabled."""
return UserFlags.disable_premium.value
@fill_with_flags() @fill_with_flags()
class MemberCacheFlags(BaseFlags): class MemberCacheFlags(BaseFlags):
@ -576,7 +581,7 @@ class MemberCacheFlags(BaseFlags):
def __init__(self, **kwargs: bool): def __init__(self, **kwargs: bool):
bits = max(self.VALID_FLAGS.values()).bit_length() bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1 self.value: int = (1 << bits) - 1
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f'{key!r} is not a valid flag name.')

84
discord/gateway.py

@ -25,7 +25,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
import concurrent.futures
import logging import logging
import struct import struct
import time import time
@ -54,8 +53,11 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from .client import Client from .client import Client
from .state import ConnectionState from .state import ConnectionState
from .types.snowflake import Snowflake
from .voice_client import VoiceClient from .voice_client import VoiceClient
@ -108,9 +110,6 @@ class GatewayRatelimiter:
return self.per - (current - self.window) return self.per - (current - self.window)
self.remaining -= 1 self.remaining -= 1
if self.remaining == 0:
self.window = current
return 0.0 return 0.0
async def block(self) -> None: async def block(self) -> None:
@ -222,7 +221,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self._last_recv = ack_time self._last_recv = ack_time
self.latency = ack_time - self._last_send self.latency: float = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
if self.latency > 10: if self.latency > 10:
_log.warning(self.behind_msg, self.latency) _log.warning(self.behind_msg, self.latency)
@ -345,7 +344,7 @@ class DiscordWebSocket:
@classmethod @classmethod
async def from_client( async def from_client(
cls: Type[DWS], cls,
client: Client, client: Client,
*, *,
initial: bool = False, initial: bool = False,
@ -353,7 +352,7 @@ class DiscordWebSocket:
session: Optional[str] = None, session: Optional[str] = None,
sequence: Optional[int] = None, sequence: Optional[int] = None,
resume: bool = False, resume: bool = False,
) -> DWS: ) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
@ -662,21 +661,21 @@ class DiscordWebSocket:
if activities is not None: if activities is not None:
if not all(isinstance(activity, BaseActivity) for activity in activities): if not all(isinstance(activity, BaseActivity) for activity in activities):
raise TypeError('activity must derive from BaseActivity') raise TypeError('activity must derive from BaseActivity')
activities = [activity.to_dict() for activity in activities] activities_data = [activity.to_dict() for activity in activities]
else: else:
activities = [] activities_data = []
if status == 'idle': if status == 'idle':
since = int(time.time() * 1000) since = int(time.time() * 1000)
payload = {'op': self.PRESENCE, 'd': {'activities': activities, 'afk': afk, 'since': since, 'status': str(status)}} payload = {'op': self.PRESENCE, 'd': {'activities': activities_data, 'afk': afk, 'since': since, 'status': str(status)}}
sent = utils._to_json(payload) sent = utils._to_json(payload)
_log.debug('Sending "%s" to change presence.', sent) _log.debug('Sending "%s" to change presence.', sent)
await self.send(sent) await self.send(sent)
async def request_lazy_guild( async def request_lazy_guild(
self, guild_id, *, typing=None, threads=None, activities=None, members=None, channels=None, thread_member_lists=None self, guild_id: Snowflake, *, typing: Optional[bool] = None, threads: Optional[bool] = None, activities: Optional[bool] = None, members: Optional[List[Snowflake]]=None, channels: Optional[Dict[Snowflake, List[List[int]]]]=None, thread_member_lists: Optional[List[Snowflake]]=None
): ):
payload = { payload = {
'op': self.GUILD_SUBSCRIBE, 'op': self.GUILD_SUBSCRIBE,
@ -704,11 +703,11 @@ class DiscordWebSocket:
async def request_chunks( async def request_chunks(
self, self,
guild_ids: List[int], guild_ids: List[Snowflake],
query: Optional[str] = None, query: Optional[str] = None,
*, *,
limit: Optional[int] = None, limit: Optional[int] = None,
user_ids: Optional[List[int]] = None, user_ids: Optional[List[Snowflake]] = None,
presences: bool = True, presences: bool = True,
nonce: Optional[str] = None, nonce: Optional[str] = None,
) -> None: ) -> None:
@ -723,7 +722,7 @@ class DiscordWebSocket:
}, },
} }
if nonce: if nonce is not None:
payload['d']['nonce'] = nonce payload['d']['nonce'] = nonce
await self.send_as_json(payload) await self.send_as_json(payload)
@ -755,7 +754,7 @@ class DiscordWebSocket:
_log.debug('Updating %s voice state to %s.', guild_id or 'client', payload) _log.debug('Updating %s voice state to %s.', guild_id or 'client', payload)
await self.send_as_json(payload) await self.send_as_json(payload)
async def access_dm(self, channel_id: int): async def access_dm(self, channel_id: Snowflake):
payload = {'op': self.CALL_CONNECT, 'd': {'channel_id': str(channel_id)}} payload = {'op': self.CALL_CONNECT, 'd': {'channel_id': str(channel_id)}}
_log.debug('Sending ACCESS_DM for channel %s.', channel_id) _log.debug('Sending ACCESS_DM for channel %s.', channel_id)
@ -763,7 +762,7 @@ class DiscordWebSocket:
async def request_commands( async def request_commands(
self, self,
guild_id: int, guild_id: Snowflake,
type: int, type: int,
*, *,
nonce: Optional[str] = None, nonce: Optional[str] = None,
@ -771,13 +770,13 @@ class DiscordWebSocket:
applications: Optional[bool] = None, applications: Optional[bool] = None,
offset: int = 0, offset: int = 0,
query: Optional[str] = None, query: Optional[str] = None,
command_ids: Optional[List[int]] = None, command_ids: Optional[List[Snowflake]] = None,
application_id: Optional[int] = None, application_id: Optional[Snowflake] = None,
) -> None: ) -> None:
payload = { payload = {
'op': self.REQUEST_COMMANDS, 'op': self.REQUEST_COMMANDS,
'd': { 'd': {
'guild_id': guild_id, 'guild_id': str(guild_id),
'type': type, 'type': type,
}, },
} }
@ -795,7 +794,7 @@ class DiscordWebSocket:
if command_ids is not None: if command_ids is not None:
payload['d']['command_ids'] = command_ids payload['d']['command_ids'] = command_ids
if application_id is not None: if application_id is not None:
payload['d']['application_id'] = application_id payload['d']['application_id'] = str(application_id)
await self.send_as_json(payload) await self.send_as_json(payload)
@ -871,11 +870,11 @@ class DiscordVoiceWebSocket:
*, *,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> None: ) -> None:
self.ws = socket self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive = None self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code = None self._close_code: Optional[int] = None
self.secret_key = None self.secret_key: Optional[str] = None
if hook: if hook:
self._hook = hook # type: ignore - type-checker doesn't like overriding methods self._hook = hook # type: ignore - type-checker doesn't like overriding methods
@ -914,7 +913,9 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
@classmethod @classmethod
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS: async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http http = client._state.http
@ -971,29 +972,30 @@ class DiscordVoiceWebSocket:
async def received_message(self, msg: Dict[str, Any]) -> None: async def received_message(self, msg: Dict[str, Any]) -> None:
_log.debug('Voice gateway event: %s.', msg) _log.debug('Voice gateway event: %s.', msg)
op = msg['op'] op = msg['op']
data = msg.get('d') data = msg['d'] # According to Discord this key is always given
if op == self.READY: if op == self.READY:
await self.initial_connection(data) # type: ignore - type-checker thinks data could be None await self.initial_connection(data)
elif op == self.HEARTBEAT_ACK: elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack() # type: ignore - _keep_alive can't be None at this point if self._keep_alive:
self._keep_alive.ack()
elif op == self.RESUMED: elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.') _log.info('Voice RESUME succeeded.')
self.secret_key = self._connection.secret_key self.secret_key = self._connection.secret_key
elif op == self.SELECT_PROTOCOL_ACK: elif op == self.SELECT_PROTOCOL_ACK:
self._connection.mode = data['mode'] # type: ignore - data can't be None at this point self._connection.mode = data['mode']
await self.load_secret_key(data) # type: ignore - data can't be None at this point await self.load_secret_key(data)
elif op == self.HELLO: elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0 # type: ignore - type-checker thinks data could be None interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start() self._keep_alive.start()
elif op == self.SPEAKING: elif op == self.SPEAKING:
state = self._connection state = self._connection
user_id = int(data['user_id']) # type: ignore - data can't be None at this point user_id = int(data['user_id'])
speaking = data['speaking'] # type: ignore - data can't be None at this point speaking = data['speaking']
ssrc = state._flip_ssrc(user_id) ssrc = state._flip_ssrc(user_id)
if ssrc is None: if ssrc is None:
state._set_ssrc(user_id, SSRC(data['ssrc'], speaking)) # type: ignore - data can't be None at this point state._set_ssrc(user_id, SSRC(data['ssrc'], speaking))
else: else:
ssrc.speaking = speaking ssrc.speaking = speaking
@ -1019,17 +1021,17 @@ class DiscordVoiceWebSocket:
# The IP is ascii starting at the 4th byte and ending at the first null # The IP is ascii starting at the 4th byte and ending at the first null
ip_start = 4 ip_start = 4
ip_end = recv.index(0, ip_start) ip_end = recv.index(0, ip_start)
state.endpoint_ip = recv[ip_start:ip_end].decode('ascii') state.ip = recv[ip_start:ip_end].decode('ascii')
state.voice_port = struct.unpack_from('>H', recv, len(recv) - 2)[0] state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.endpoint_ip, state.voice_port) _log.debug('detected ip: %s port: %s', state.ip, state.port)
# There *should* always be at least one supported mode (xsalsa20_poly1305) # There *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('Received supported encryption modes: %s.', ", ".join(modes)) _log.debug('Received supported encryption modes: %s.', ", ".join(modes))
mode = modes[0] mode = modes[0]
await self.select_protocol(state.endpoint_ip, state.voice_port, mode) await self.select_protocol(state.ip, state.port, mode)
_log.info('Selected the voice protocol for use: %s.', mode) _log.info('Selected the voice protocol for use: %s.', mode)
@property @property
@ -1047,9 +1049,9 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data): async def load_secret_key(self, data: Dict[str, Any]) -> None:
_log.info('Received secret key for voice connection.') _log.info('Received secret key for voice connection.')
self.secret_key = self._connection.secret_key = data.get('secret_key') self.secret_key = self._connection.secret_key = data['secret_key']
await self.speak() await self.speak()
await self.speak(SpeakingState.none) await self.speak(SpeakingState.none)

139
discord/guild.py

@ -32,9 +32,11 @@ from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
ClassVar, ClassVar,
Collection,
Coroutine, Coroutine,
Dict, Dict,
List, List,
Mapping,
NamedTuple, NamedTuple,
Sequence, Sequence,
Set, Set,
@ -129,6 +131,7 @@ if TYPE_CHECKING:
) )
from .types.integration import IntegrationType from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList from .types.snowflake import SnowflakeList
from .types.widget import EditWidgetSettings
VocalGuildChannel = Union[VoiceChannel, StageChannel] VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel] GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
@ -340,6 +343,7 @@ class Guild(Hashable):
self._state: ConnectionState = state self._state: ConnectionState = state
self.notification_settings: Optional[GuildSettings] = None self.notification_settings: Optional[GuildSettings] = None
self.command_counts: Optional[CommandCounts] = None self.command_counts: Optional[CommandCounts] = None
self._member_count: int = 0
self._from_data(data) self._from_data(data)
def _add_channel(self, channel: GuildChannel, /) -> None: def _add_channel(self, channel: GuildChannel, /) -> None:
@ -384,7 +388,7 @@ class Guild(Hashable):
('id', self.id), ('id', self.id),
('name', self.name), ('name', self.name),
('chunked', self.chunked), ('chunked', self.chunked),
('member_count', getattr(self, '_member_count', None)), ('member_count', self._member_count),
) )
inner = ' '.join('%s=%r' % t for t in attrs) inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Guild {inner}>' return f'<Guild {inner}>'
@ -435,9 +439,10 @@ class Guild(Hashable):
return role return role
def _from_data(self, guild: GuildPayload) -> None: def _from_data(self, guild: GuildPayload) -> None:
member_count = guild.get('member_count', guild.get('approximate_member_count')) try:
if member_count is not None: self._member_count: int = guild['member_count']
self._member_count: int = member_count except KeyError:
pass
self.id: int = int(guild['id']) self.id: int = int(guild['id'])
self.name: str = guild.get('name', '') self.name: str = guild.get('name', '')
@ -506,7 +511,7 @@ class Guild(Hashable):
self.owner_application_id: Optional[int] = utils._get_as_snowflake(guild, 'application_id') self.owner_application_id: Optional[int] = utils._get_as_snowflake(guild, 'application_id')
self.premium_progress_bar_enabled: bool = guild.get('premium_progress_bar_enabled', False) self.premium_progress_bar_enabled: bool = guild.get('premium_progress_bar_enabled', False)
large = None if member_count is None else member_count >= 250 large = None if self._member_count is 0 else self._member_count >= 250
self._large: Optional[bool] = guild.get('large', large) self._large: Optional[bool] = guild.get('large', large)
if (settings := guild.get('settings')) is not None: if (settings := guild.get('settings')) is not None:
@ -552,9 +557,8 @@ class Guild(Hashable):
members, which for this library is set to the maximum of 250. members, which for this library is set to the maximum of 250.
""" """
if self._large is None: if self._large is None:
try: if self._member_count is not None:
return self._member_count >= 250 return self._member_count >= 250
except AttributeError:
return len(self._members) >= 250 return len(self._members) >= 250
return self._large return self._large
@ -958,13 +962,16 @@ class Guild(Hashable):
return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes') return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes')
@property @property
def member_count(self) -> int: def member_count(self) -> Optional[int]:
""":class:`int`: Returns the true member count regardless of it being loaded fully or not. """Optional[:class:`int`]: Returns the member count if available.
.. warning:: .. warning::
Due to a Discord limitation, this may not always be up-to-date and accurate. Due to a Discord limitation, this may not always be up-to-date and accurate.
.. versionchanged:: 2.0
Now returns an ``Optional[int]``.
""" """
return self._member_count return self._member_count
@ -1054,7 +1061,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.text], channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, TextChannelPayload]: ) -> Coroutine[Any, Any, TextChannelPayload]:
@ -1065,7 +1072,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.voice], channel_type: Literal[ChannelType.voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, VoiceChannelPayload]: ) -> Coroutine[Any, Any, VoiceChannelPayload]:
@ -1076,7 +1083,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.stage_voice], channel_type: Literal[ChannelType.stage_voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, StageChannelPayload]: ) -> Coroutine[Any, Any, StageChannelPayload]:
@ -1087,7 +1094,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.category], channel_type: Literal[ChannelType.category],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, CategoryChannelPayload]: ) -> Coroutine[Any, Any, CategoryChannelPayload]:
@ -1098,7 +1105,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.news], channel_type: Literal[ChannelType.news],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, NewsChannelPayload]: ) -> Coroutine[Any, Any, NewsChannelPayload]:
@ -1109,7 +1116,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.store], channel_type: Literal[ChannelType.store],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, StoreChannelPayload]: ) -> Coroutine[Any, Any, StoreChannelPayload]:
@ -1120,7 +1127,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: Literal[ChannelType.text], channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]: ) -> Coroutine[Any, Any, GuildChannelPayload]:
@ -1131,7 +1138,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: ChannelType, channel_type: ChannelType,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ..., overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ..., category: Optional[Snowflake] = ...,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]: ) -> Coroutine[Any, Any, GuildChannelPayload]:
@ -1141,13 +1148,13 @@ class Guild(Hashable):
self, self,
name: str, name: str,
channel_type: ChannelType, channel_type: ChannelType,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
category: Optional[Snowflake] = None, category: Optional[Snowflake] = None,
**options: Any, **options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]: ) -> Coroutine[Any, Any, GuildChannelPayload]:
if overwrites is MISSING: if overwrites is MISSING:
overwrites = {} overwrites = {}
elif not isinstance(overwrites, dict): elif not isinstance(overwrites, Mapping):
raise TypeError('overwrites parameter expects a dict') raise TypeError('overwrites parameter expects a dict')
perms = [] perms = []
@ -1180,7 +1187,7 @@ class Guild(Hashable):
topic: str = MISSING, topic: str = MISSING,
slowmode_delay: int = MISSING, slowmode_delay: int = MISSING,
nsfw: bool = MISSING, nsfw: bool = MISSING,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
) -> TextChannel: ) -> TextChannel:
"""|coro| """|coro|
@ -1201,8 +1208,8 @@ class Guild(Hashable):
will be required to update the position of the channel in the channel list. will be required to update the position of the channel in the channel list.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Examples Examples
---------- ----------
@ -1297,15 +1304,15 @@ class Guild(Hashable):
user_limit: int = MISSING, user_limit: int = MISSING,
rtc_region: Optional[str] = MISSING, rtc_region: Optional[str] = MISSING,
video_quality_mode: VideoQualityMode = MISSING, video_quality_mode: VideoQualityMode = MISSING,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
) -> VoiceChannel: ) -> VoiceChannel:
"""|coro| """|coro|
This is similar to :meth:`create_text_channel` except makes a :class:`VoiceChannel` instead. This is similar to :meth:`create_text_channel` except makes a :class:`VoiceChannel` instead.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -1383,7 +1390,7 @@ class Guild(Hashable):
*, *,
topic: str, topic: str,
position: int = MISSING, position: int = MISSING,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
category: Optional[CategoryChannel] = None, category: Optional[CategoryChannel] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> StageChannel: ) -> StageChannel:
@ -1394,8 +1401,8 @@ class Guild(Hashable):
.. versionadded:: 1.7 .. versionadded:: 1.7
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -1451,7 +1458,7 @@ class Guild(Hashable):
self, self,
name: str, name: str,
*, *,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING, overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
reason: Optional[str] = None, reason: Optional[str] = None,
position: int = MISSING, position: int = MISSING,
) -> CategoryChannel: ) -> CategoryChannel:
@ -1465,8 +1472,8 @@ class Guild(Hashable):
cannot have categories. cannot have categories.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Raises Raises
------ ------
@ -1575,8 +1582,8 @@ class Guild(Hashable):
The ``region`` keyword parameter has been removed. The ``region`` keyword parameter has been removed.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
The ``preferred_locale`` keyword parameter now accepts an enum instead of :class:`str`. The ``preferred_locale`` keyword parameter now accepts an enum instead of :class:`str`.
@ -1635,11 +1642,11 @@ class Guild(Hashable):
The new preferred locale for the guild. Used as the primary language in the guild. The new preferred locale for the guild. Used as the primary language in the guild.
rules_channel: Optional[:class:`TextChannel`] rules_channel: Optional[:class:`TextChannel`]
The new channel that is used for rules. This is only available to The new channel that is used for rules. This is only available to
guilds that contain ``PUBLIC`` in :attr:`Guild.features`. Could be ``None`` for no rules guilds that contain ``COMMUNITY`` in :attr:`Guild.features`. Could be ``None`` for no rules
channel. channel.
public_updates_channel: Optional[:class:`TextChannel`] public_updates_channel: Optional[:class:`TextChannel`]
The new channel that is used for public updates from Discord. This is only available to The new channel that is used for public updates from Discord. This is only available to
guilds that contain ``PUBLIC`` in :attr:`Guild.features`. Could be ``None`` for no guilds that contain ``COMMUNITY`` in :attr:`Guild.features`. Could be ``None`` for no
public updates channel. public updates channel.
premium_progress_bar_enabled: :class:`bool` premium_progress_bar_enabled: :class:`bool`
Whether the premium AKA server boost level progress bar should be enabled for the guild. Whether the premium AKA server boost level progress bar should be enabled for the guild.
@ -1875,7 +1882,7 @@ class Guild(Hashable):
HTTPException HTTPException
Fetching the profile failed. Fetching the profile failed.
InvalidData InvalidData
The profile is not from this guild. The member is not in this guild.
Returns Returns
-------- --------
@ -1928,7 +1935,7 @@ class Guild(Hashable):
:class:`BanEntry` :class:`BanEntry`
The :class:`BanEntry` object for the specified user. The :class:`BanEntry` object for the specified user.
""" """
data: BanPayload = await self._state.http.get_ban(user.id, self.id) data = await self._state.http.get_ban(user.id, self.id)
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason']) return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason'])
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
@ -1996,7 +2003,7 @@ class Guild(Hashable):
List[:class:`BanEntry`] List[:class:`BanEntry`]
A list of :class:`BanEntry` objects. A list of :class:`BanEntry` objects.
""" """
data: List[BanPayload] = await self._state.http.get_bans(self.id) data = await self._state.http.get_bans(self.id)
return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data] return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data]
async def prune_members( async def prune_members(
@ -2004,7 +2011,7 @@ class Guild(Hashable):
*, *,
days: int, days: int,
compute_prune_count: bool = True, compute_prune_count: bool = True,
roles: List[Snowflake] = MISSING, roles: Collection[Snowflake] = MISSING,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Optional[int]: ) -> Optional[int]:
r"""|coro| r"""|coro|
@ -2026,8 +2033,8 @@ class Guild(Hashable):
The ``roles`` keyword-only parameter was added. The ``roles`` keyword-only parameter was added.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -2118,7 +2125,7 @@ class Guild(Hashable):
data = await self._state.http.guild_webhooks(self.id) data = await self._state.http.guild_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data] return [Webhook.from_state(d, state=self._state) for d in data]
async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> Optional[int]: async def estimate_pruned_members(self, *, days: int, roles: Collection[Snowflake] = MISSING) -> Optional[int]:
"""|coro| """|coro|
Similar to :meth:`prune_members` except instead of actually Similar to :meth:`prune_members` except instead of actually
@ -2129,8 +2136,8 @@ class Guild(Hashable):
The returned value can be ``None``. The returned value can be ``None``.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -2673,7 +2680,7 @@ class Guild(Hashable):
*, *,
name: str, name: str,
image: bytes, image: bytes,
roles: List[Role] = MISSING, roles: Collection[Role] = MISSING,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Emoji: ) -> Emoji:
r"""|coro| r"""|coro|
@ -2831,8 +2838,8 @@ class Guild(Hashable):
The ``display_icon`` keyword-only parameter was added. The ``display_icon`` keyword-only parameter was added.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -2925,7 +2932,7 @@ class Guild(Hashable):
# TODO: add to cache # TODO: add to cache
return role return role
async def edit_role_positions(self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]: async def edit_role_positions(self, positions: Mapping[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]:
"""|coro| """|coro|
Bulk edits a list of :class:`Role` in the guild. Bulk edits a list of :class:`Role` in the guild.
@ -2936,8 +2943,8 @@ class Guild(Hashable):
.. versionadded:: 1.4 .. versionadded:: 1.4
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Example Example
---------- ----------
@ -2974,7 +2981,7 @@ class Guild(Hashable):
List[:class:`Role`] List[:class:`Role`]
A list of all the roles in the guild. A list of all the roles in the guild.
""" """
if not isinstance(positions, dict): if not isinstance(positions, Mapping):
raise TypeError('positions parameter expects a dict') raise TypeError('positions parameter expects a dict')
role_positions = [] role_positions = []
@ -3024,7 +3031,7 @@ class Guild(Hashable):
user: Snowflake, user: Snowflake,
*, *,
reason: Optional[str] = None, reason: Optional[str] = None,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1, delete_message_days: int = 1,
) -> None: ) -> None:
"""|coro| """|coro|
@ -3080,6 +3087,16 @@ class Guild(Hashable):
""" """
await self._state.http.unban(user.id, self.id, reason=reason) await self._state.http.unban(user.id, self.id, reason=reason)
@property
def vanity_url(self) -> Optional[str]:
"""Optional[:class:`str`]: The Discord vanity invite URL for this guild, if available.
.. versionadded:: 2.0
"""
if self.vanity_url_code is None:
return None
return f'{Invite.BASE}/{self.vanity_url_code}'
async def vanity_invite(self) -> Optional[Invite]: async def vanity_invite(self) -> Optional[Invite]:
"""|coro| """|coro|
@ -3230,7 +3247,7 @@ class Guild(Hashable):
after = Object(id=utils.time_snowflake(after, high=True)) after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is MISSING: if oldest_first is MISSING:
reverse = after is not None reverse = after is not MISSING
else: else:
reverse = oldest_first reverse = oldest_first
@ -3340,7 +3357,7 @@ class Guild(Hashable):
HTTPException HTTPException
Editing the widget failed. Editing the widget failed.
""" """
payload = {} payload: EditWidgetSettings = {}
if channel is not MISSING: if channel is not MISSING:
payload['channel_id'] = None if channel is None else channel.id payload['channel_id'] = None if channel is None else channel.id
if enabled is not MISSING: if enabled is not MISSING:
@ -3452,14 +3469,14 @@ class Guild(Hashable):
The members that belong to this guild. The members that belong to this guild.
""" """
if self._offline_members_hidden: if self._offline_members_hidden:
raise ClientException('This guild cannot be chunked.') raise ClientException('This guild cannot be chunked')
if self._state.is_guild_evicted(self): if self._state.is_guild_evicted(self):
raise ClientException('This guild is no longer available.') raise ClientException('This guild is no longer available')
members = await self._state.chunk_guild(self, channels=[channel] if channel else []) members = await self._state.chunk_guild(self, channels=[channel] if channel else [])
if members is None: if members is None:
raise ClientException('Chunking failed.') raise ClientException('Chunking failed')
return members # type: ignore return members
async def fetch_members( async def fetch_members(
self, self,
@ -3509,14 +3526,14 @@ class Guild(Hashable):
The members that belong to this guild (offline members may not be included). The members that belong to this guild (offline members may not be included).
""" """
if self._state.is_guild_evicted(self): if self._state.is_guild_evicted(self):
raise ClientException('This guild is no longer available.') raise ClientException('This guild is no longer available')
members = await self._state.scrape_guild( members = await self._state.scrape_guild(
self, cache=cache, force_scraping=force_scraping, delay=delay, channels=channels self, cache=cache, force_scraping=force_scraping, delay=delay, channels=channels
) )
if members is None: if members is None:
raise ClientException('Fetching members failed') raise ClientException('Fetching members failed')
return members # type: ignore return members
async def query_members( async def query_members(
self, self,
@ -3533,7 +3550,7 @@ class Guild(Hashable):
Request members that belong to this guild whose username starts with Request members that belong to this guild whose username starts with
the query given. the query given.
This is a websocket operation and can be slow. This is a websocket operation.
.. note:: .. note::
This is preferrable to using :meth:`fetch_member` as the client uses This is preferrable to using :meth:`fetch_member` as the client uses

73
discord/http.py

@ -57,6 +57,7 @@ from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordSer
from .file import File from .file import File
from .tracking import ContextProperties from .tracking import ContextProperties
from . import utils from . import utils
from .mentions import AllowedMentions
from .utils import MISSING from .utils import MISSING
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -78,6 +79,7 @@ if TYPE_CHECKING:
appinfo, appinfo,
audit_log, audit_log,
channel, channel,
command,
emoji, emoji,
guild, guild,
integration, integration,
@ -91,7 +93,6 @@ if TYPE_CHECKING:
widget, widget,
team, team,
threads, threads,
voice,
scheduled_event, scheduled_event,
sticker, sticker,
welcome_screen, welcome_screen,
@ -121,9 +122,9 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
class MultipartParameters(NamedTuple): class MultipartParameters(NamedTuple):
payload: Optional[Dict[str, Any]] payload: Optional[Dict[str, Any]]
multipart: Optional[List[Dict[str, Any]]] multipart: Optional[List[Dict[str, Any]]]
files: Optional[List[File]] files: Optional[Sequence[File]]
def __enter__(self): def __enter__(self) -> Self:
return self return self
def __exit__( def __exit__(
@ -146,10 +147,10 @@ def handle_message_parameters(
nonce: Optional[Union[int, str]] = None, nonce: Optional[Union[int, str]] = None,
flags: MessageFlags = MISSING, flags: MessageFlags = MISSING,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Optional[Embed] = MISSING, embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING, attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = MISSING, allowed_mentions: Optional[AllowedMentions] = MISSING,
message_reference: Optional[message.MessageReference] = MISSING, message_reference: Optional[message.MessageReference] = MISSING,
stickers: Optional[SnowflakeList] = MISSING, stickers: Optional[SnowflakeList] = MISSING,
@ -215,15 +216,12 @@ def handle_message_parameters(
payload['allowed_mentions'] = previous_allowed_mentions.to_dict() payload['allowed_mentions'] = previous_allowed_mentions.to_dict()
if mention_author is not None: if mention_author is not None:
try: if 'allowed_mentions' not in payload:
payload['allowed_mentions'] = AllowedMentions().to_dict()
payload['allowed_mentions']['replied_user'] = mention_author payload['allowed_mentions']['replied_user'] = mention_author
except KeyError:
payload['allowed_mentions'] = {
'replied_user': mention_author,
}
if attachments is MISSING: if attachments is MISSING:
attachments = files # type: ignore attachments = files
else: else:
files = [a for a in attachments if isinstance(a, File)] files = [a for a in attachments if isinstance(a, File)]
@ -322,23 +320,24 @@ class HTTPClient:
def __init__( def __init__(
self, self,
loop: asyncio.AbstractEventLoop,
connector: Optional[aiohttp.BaseConnector] = None, connector: Optional[aiohttp.BaseConnector] = None,
*, *,
proxy: Optional[str] = None, proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None, proxy_auth: Optional[aiohttp.BasicAuth] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True, unsync_clock: bool = True,
http_trace: Optional[aiohttp.TraceConfig] = None,
) -> None: ) -> None:
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self.loop: asyncio.AbstractEventLoop = loop
self.connector: aiohttp.BaseConnector = connector or aiohttp.TCPConnector(limit=0) self.connector: aiohttp.BaseConnector = connector or MISSING
self.__session: aiohttp.ClientSession = MISSING self.__session: aiohttp.ClientSession = MISSING
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self._global_over: asyncio.Event = asyncio.Event() self._global_over: asyncio.Event = MISSING
self._global_over.set()
self.token: Optional[str] = None self.token: Optional[str] = None
self.ack_token: Optional[str] = None self.ack_token: Optional[str] = None
self.proxy: Optional[str] = proxy self.proxy: Optional[str] = proxy
self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth
self.http_trace: Optional[aiohttp.TraceConfig] = http_trace
self.use_clock: bool = not unsync_clock self.use_clock: bool = not unsync_clock
self.user_agent: str = MISSING self.user_agent: str = MISSING
@ -357,7 +356,12 @@ class HTTPClient:
async def startup(self) -> None: async def startup(self) -> None:
if self._started: if self._started:
return return
self.__session = session = aiohttp.ClientSession(connector=self.connector)
self.__session = session = aiohttp.ClientSession(
connector=self.connector,
loop=self.loop,
trace_configs=None if self.http_trace is None else [self.http_trace],
)
self.user_agent, self.browser_version, self.client_build_number = ua, bv, bn = await utils._get_info(session) self.user_agent, self.browser_version, self.client_build_number = ua, bv, bn = await utils._get_info(session)
_log.info('Found user agent %s (%s), build number %s.', ua, bv, bn) _log.info('Found user agent %s (%s), build number %s.', ua, bv, bn)
self.super_properties = sp = { self.super_properties = sp = {
@ -587,7 +591,11 @@ class HTTPClient:
def recreate(self) -> None: def recreate(self) -> None:
if self.__session and self.__session.closed: if self.__session and self.__session.closed:
self.__session = aiohttp.ClientSession(connector=self.connector) self.__session = aiohttp.ClientSession(
connector=self.connector,
loop=self.loop,
trace_configs=None if self.http_trace is None else [self.http_trace],
)
async def close(self) -> None: async def close(self) -> None:
if self.__session: if self.__session:
@ -599,13 +607,19 @@ class HTTPClient:
self.token = token self.token = token
self.ack_token = None self.ack_token = None
def get_me(self, with_analytics_token=True) -> Response[user.User]: def get_me(self, with_analytics_token: bool = True) -> Response[user.User]:
params = {'with_analytics_token': str(with_analytics_token).lower()} params = {'with_analytics_token': str(with_analytics_token).lower()}
return self.request(Route('GET', '/users/@me'), params=params) return self.request(Route('GET', '/users/@me'), params=params)
async def static_login(self, token: str) -> user.User: async def static_login(self, token: str) -> user.User:
old_token, self.token = self.token, token old_token, self.token = self.token, token
if self.connector is MISSING:
self.connector = aiohttp.TCPConnector(loop=self.loop, limit=0)
self._global_over = asyncio.Event()
self._global_over.set()
await self.startup() await self.startup()
try: try:
@ -1254,15 +1268,14 @@ class HTTPClient:
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]: def edit_template(self, guild_id: Snowflake, code: str, payload: Dict[str, Any]) -> Response[template.Template]:
r = Route('PATCH', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)
valid_keys = ( valid_keys = (
'name', 'name',
'description', 'description',
) )
payload = {k: v for k, v in payload.items() if k in valid_keys} payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(r, json=payload) return self.request(Route('PATCH', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code), json=payload)
def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]: def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]:
return self.request(Route('DELETE', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) return self.request(Route('DELETE', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
@ -1300,7 +1313,7 @@ class HTTPClient:
guild_id: Snowflake, guild_id: Snowflake,
days: int, days: int,
compute_prune_count: bool, compute_prune_count: bool,
roles: List[str], roles: Iterable[str],
*, *,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Response[guild.GuildPrune]: ) -> Response[guild.GuildPrune]:
@ -1317,7 +1330,7 @@ class HTTPClient:
self, self,
guild_id: Snowflake, guild_id: Snowflake,
days: int, days: int,
roles: List[str], roles: Iterable[str],
) -> Response[guild.GuildPrune]: ) -> Response[guild.GuildPrune]:
params: Dict[str, Any] = { params: Dict[str, Any] = {
'days': days, 'days': days,
@ -1330,6 +1343,9 @@ class HTTPClient:
def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]: def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]:
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id)) return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
def get_sticker_guild(self, sticker_id: Snowflake) -> Response[guild.Guild]:
return self.request(Route('GET', '/stickers/{sticker_id}/guild', sticker_id=sticker_id))
def list_premium_sticker_packs( def list_premium_sticker_packs(
self, country: str = 'US', locale: str = 'en-US', payment_source_id: Snowflake = MISSING self, country: str = 'US', locale: str = 'en-US', payment_source_id: Snowflake = MISSING
) -> Response[sticker.ListPremiumStickerPacks]: ) -> Response[sticker.ListPremiumStickerPacks]:
@ -1413,6 +1429,9 @@ class HTTPClient:
def get_custom_emoji(self, guild_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]: def get_custom_emoji(self, guild_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]:
return self.request(Route('GET', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id)) return self.request(Route('GET', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id))
def get_emoji_guild(self, emoji_id: Snowflake) -> Response[guild.Guild]:
return self.request(Route('GET', '/emojis/{emoji_id}', emoji_id=emoji_id))
def create_custom_emoji( def create_custom_emoji(
self, self,
guild_id: Snowflake, guild_id: Snowflake,
@ -1533,7 +1552,9 @@ class HTTPClient:
def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]: def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]:
return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id)) return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id))
def edit_widget(self, guild_id: Snowflake, payload, reason: Optional[str] = None) -> Response[widget.WidgetSettings]: def edit_widget(
self, guild_id: Snowflake, payload: widget.EditWidgetSettings, reason: Optional[str] = None
) -> Response[widget.WidgetSettings]:
return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason) return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason)
def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]: def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]:

21
discord/integrations.py

@ -39,6 +39,9 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .guild import Guild
from .role import Role
from .state import ConnectionState
from .types.integration import ( from .types.integration import (
IntegrationAccount as IntegrationAccountPayload, IntegrationAccount as IntegrationAccountPayload,
Integration as IntegrationPayload, Integration as IntegrationPayload,
@ -47,8 +50,6 @@ if TYPE_CHECKING:
IntegrationType, IntegrationType,
IntegrationApplication as IntegrationApplicationPayload, IntegrationApplication as IntegrationApplicationPayload,
) )
from .guild import Guild
from .role import Role
class IntegrationAccount: class IntegrationAccount:
@ -109,11 +110,11 @@ class Integration:
) )
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
self.guild = guild self.guild: Guild = guild
self._state = guild._state self._state: ConnectionState = guild._state
self._from_data(data) self._from_data(data)
def __repr__(self): def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>" return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None: def _from_data(self, data: IntegrationPayload) -> None:
@ -123,7 +124,7 @@ class Integration:
self.account: IntegrationAccount = IntegrationAccount(data['account']) self.account: IntegrationAccount = IntegrationAccount(data['account'])
user = data.get('user') user = data.get('user')
self.user = User(state=self._state, data=user) if user else None self.user: Optional[User] = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled'] self.enabled: bool = data['enabled']
async def delete(self, *, reason: Optional[str] = None) -> None: async def delete(self, *, reason: Optional[str] = None) -> None:
@ -232,8 +233,8 @@ class StreamIntegration(Integration):
do this. do this.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` instead of
:exc:`TypeError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -319,7 +320,7 @@ class IntegrationApplication:
'user', 'user',
) )
def __init__(self, *, data: IntegrationApplicationPayload, state): def __init__(self, *, data: IntegrationApplicationPayload, state: ConnectionState) -> None:
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.name: str = data['name'] self.name: str = data['name']
self.icon: Optional[str] = data['icon'] self.icon: Optional[str] = data['icon']
@ -358,7 +359,7 @@ class BotIntegration(Integration):
def _from_data(self, data: BotIntegrationPayload) -> None: def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data) super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state) self.application: IntegrationApplication = IntegrationApplication(data=data['application'], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]: def _integration_factory(value: str) -> Tuple[Type[Integration], str]:

59
discord/invite.py

@ -30,7 +30,7 @@ from .utils import parse_time, snowflake_time, _get_as_snowflake, MISSING
from .object import Object from .object import Object
from .mixins import Hashable from .mixins import Hashable
from .scheduled_event import ScheduledEvent from .scheduled_event import ScheduledEvent
from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, try_enum from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, NSFWLevel, try_enum
from .welcome_screen import WelcomeScreen from .welcome_screen import WelcomeScreen
__all__ = ( __all__ = (
@ -165,9 +165,34 @@ class PartialInviteGuild:
A list of features the guild has. See :attr:`Guild.features` for more information. A list of features the guild has. See :attr:`Guild.features` for more information.
description: Optional[:class:`str`] description: Optional[:class:`str`]
The partial guild's description. The partial guild's description.
nsfw_level: :class:`NSFWLevel`
The partial guild's NSFW level.
.. versionadded:: 2.0
vanity_url_code: Optional[:class:`str`]
The partial guild's vanity URL code, if available.
.. versionadded:: 2.0
premium_subscription_count: :class:`int`
The number of "boosts" the partial guild currently has.
.. versionadded:: 2.0
""" """
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description') __slots__ = (
'_state',
'_icon',
'_banner',
'_splash',
'features',
'id',
'name',
'verification_level',
'description',
'vanity_url_code',
'nsfw_level',
'premium_subscription_count',
)
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
self._state: ConnectionState = state self._state: ConnectionState = state
@ -179,6 +204,9 @@ class PartialInviteGuild:
self._splash: Optional[str] = data.get('splash') self._splash: Optional[str] = data.get('splash')
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level')) self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level'))
self.description: Optional[str] = data.get('description') self.description: Optional[str] = data.get('description')
self.vanity_url_code: Optional[str] = data.get('vanity_url_code')
self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, data.get('nsfw_level', 0))
self.premium_subscription_count: int = data.get('premium_subscription_count') or 0
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
@ -194,6 +222,16 @@ class PartialInviteGuild:
""":class:`datetime.datetime`: Returns the guild's creation time in UTC.""" """:class:`datetime.datetime`: Returns the guild's creation time in UTC."""
return snowflake_time(self.id) return snowflake_time(self.id)
@property
def vanity_url(self) -> Optional[str]:
"""Optional[:class:`str`]: The Discord vanity invite URL for this partial guild, if available.
.. versionadded:: 2.0
"""
if self.vanity_url_code is None:
return None
return f'{Invite.BASE}/{self.vanity_url_code}'
@property @property
def icon(self) -> Optional[Asset]: def icon(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's icon asset, if available.""" """Optional[:class:`Asset`]: Returns the guild's icon asset, if available."""
@ -446,12 +484,13 @@ class Invite(Hashable):
@classmethod @classmethod
def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self: def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self:
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = _get_as_snowflake(data, 'channel_id') channel_id = _get_as_snowflake(data, 'channel_id')
if guild_id is not None: if guild is not None:
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) or Object(id=guild_id) channel = (guild.get_channel(channel_id) or Object(id=channel_id)) if channel_id is not None else None
if channel_id is not None: else:
channel: Optional[InviteChannelType] = state.get_channel(channel_id) or Object(id=channel_id) # type: ignore guild = Object(id=guild_id) if guild_id is not None else None
channel = Object(id=channel_id) if channel_id is not None else None
return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore
@ -543,7 +582,7 @@ class Invite(Hashable):
Raises Raises
------ ------
:exc:`.HTTPException` HTTPException
Using the invite failed. Using the invite failed.
Returns Returns
@ -587,7 +626,7 @@ class Invite(Hashable):
Raises Raises
------ ------
:exc:`.HTTPException` HTTPException
Using the invite failed. Using the invite failed.
Returns Returns
@ -597,7 +636,7 @@ class Invite(Hashable):
""" """
return await self.use() return await self.use()
async def delete(self, *, reason: Optional[str] = None): async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|
Revokes the instant invite. Revokes the instant invite.

58
discord/member.py

@ -28,7 +28,7 @@ import datetime
import inspect import inspect
import itertools import itertools
from operator import attrgetter from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union from typing import Any, Callable, Collection, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type
import discord.abc import discord.abc
@ -214,7 +214,7 @@ class _ClientStatus:
return self return self
def flatten_user(cls): def flatten_user(cls: Any) -> Type[Member]:
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# Ignore private/special methods (or not) # Ignore private/special methods (or not)
# if attr.startswith('_'): # if attr.startswith('_'):
@ -331,7 +331,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
default_avatar: Asset default_avatar: Asset
avatar: Optional[Asset] avatar: Optional[Asset]
dm_channel: Optional[DMChannel] dm_channel: Optional[DMChannel]
create_dm = User.create_dm create_dm: Callable[[], Coroutine[Any, Any, DMChannel]]
mutual_guilds: List[Guild] mutual_guilds: List[Guild]
public_flags: PublicUserFlags public_flags: PublicUserFlags
banner: Optional[Asset] banner: Optional[Asset]
@ -361,10 +361,10 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>' f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
) )
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -445,7 +445,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
if self._self: if self._self:
return return
self._activities = tuple(map(create_activity, data['activities'])) self._activities = tuple(create_activity(d, self._state) for d in data['activities'])
self._client_status._update(data['status'], data['client_status']) self._client_status._update(data['status'], data['client_status'])
if len(user) > 1: if len(user) > 1:
@ -696,7 +696,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
async def ban( async def ban(
self, self,
*, *,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1, delete_message_days: int = 1,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> None: ) -> None:
"""|coro| """|coro|
@ -726,7 +726,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
mute: bool = MISSING, mute: bool = MISSING,
deafen: bool = MISSING, deafen: bool = MISSING,
suppress: bool = MISSING, suppress: bool = MISSING,
roles: List[discord.abc.Snowflake] = MISSING, roles: Collection[discord.abc.Snowflake] = MISSING,
voice_channel: Optional[VocalGuildChannel] = MISSING, voice_channel: Optional[VocalGuildChannel] = MISSING,
timed_out_until: Optional[datetime.datetime] = MISSING, timed_out_until: Optional[datetime.datetime] = MISSING,
avatar: Optional[bytes] = MISSING, avatar: Optional[bytes] = MISSING,
@ -783,7 +783,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
.. versionadded:: 1.7 .. versionadded:: 1.7
roles: List[:class:`Role`] roles: List[:class:`Role`]
The member's new list of roles. This *replaces* the roles. The member's new list of roles. This *replaces* the roles.
voice_channel: Optional[:class:`VoiceChannel`] voice_channel: Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]
The voice channel to move the member to. The voice channel to move the member to.
Pass ``None`` to kick them from voice. Pass ``None`` to kick them from voice.
timed_out_until: Optional[:class:`datetime.datetime`] timed_out_until: Optional[:class:`datetime.datetime`]
@ -913,7 +913,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
else: else:
await self._state.http.edit_my_voice_state(self.guild.id, payload) await self._state.http.edit_my_voice_state(self.guild.id, payload)
async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None: async def move_to(self, channel: Optional[VocalGuildChannel], *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|
Moves a member to a new voice channel (they must be connected first). Moves a member to a new voice channel (they must be connected first).
@ -928,7 +928,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
Parameters Parameters
----------- -----------
channel: Optional[:class:`VoiceChannel`] channel: Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]
The new voice channel to move the member to. The new voice channel to move the member to.
Pass ``None`` to kick them from voice. Pass ``None`` to kick them from voice.
reason: Optional[:class:`str`] reason: Optional[:class:`str`]
@ -936,6 +936,42 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
""" """
await self.edit(voice_channel=channel, reason=reason) await self.edit(voice_channel=channel, reason=reason)
async def timeout(self, when: Union[datetime.timedelta, datetime.datetime], /, *, reason: Optional[str] = None) -> None:
"""|coro|
Applies a time out to a member until the specified date time or for the
given :class:`datetime.timedelta`.
You must have the :attr:`~Permissions.moderate_members` permission to
use this.
This raises the same exceptions as :meth:`edit`.
Parameters
-----------
when: Union[:class:`datetime.timedelta`, :class:`datetime.datetime`]
If this is a :class:`datetime.timedelta` then it represents the amount of
time the member should be timed out for. If this is a :class:`datetime.datetime`
then it's when the member's timeout should expire. Note that the API only allows
for timeouts up to 28 days.
reason: Optional[:class:`str`]
The reason for doing this action. Shows up on the audit log.
Raises
-------
TypeError
The ``when`` parameter was the wrong type of the datetime was not timezone-aware.
"""
if isinstance(when, datetime.timedelta):
timed_out_until = utils.utcnow() + when
elif isinstance(when, datetime.datetime):
timed_out_until = when
else:
raise TypeError(f'expected datetime.datetime or datetime.timedelta not {when.__class__!r}')
await self.edit(timed_out_until=timed_out_until, reason=reason)
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
r"""|coro| r"""|coro|

10
discord/mentions.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Union, List, TYPE_CHECKING, Any, Union from typing import Union, List, TYPE_CHECKING, Any
# fmt: off # fmt: off
__all__ = ( __all__ = (
@ -92,10 +92,10 @@ class AllowedMentions:
roles: Union[bool, List[Snowflake]] = default, roles: Union[bool, List[Snowflake]] = default,
replied_user: bool = default, replied_user: bool = default,
): ):
self.everyone = everyone self.everyone: bool = everyone
self.users = users self.users: Union[bool, List[Snowflake]] = users
self.roles = roles self.roles: Union[bool, List[Snowflake]] = roles
self.replied_user = replied_user self.replied_user: bool = replied_user
@classmethod @classmethod
def all(cls) -> Self: def all(cls) -> Self:

2293
discord/message.py

File diff suppressed because it is too large

5
discord/opus.py

@ -363,7 +363,7 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None: def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def encode(self, pcm: bytes, frame_size: int) -> bytes: def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm) max_data_bytes = len(pcm)
@ -373,8 +373,7 @@ class Encoder(_OpusStruct):
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know return array.array('b', data[:ret]).tobytes()
return array.array('b', data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct): class Decoder(_OpusStruct):

15
discord/partial_emoji.py

@ -42,6 +42,7 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload from .types.message import PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
class _EmojiTag: class _EmojiTag:
@ -99,13 +100,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
id: Optional[int] id: Optional[int]
def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None): def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None):
self.animated = animated self.animated: bool = animated
self.name = name self.name: str = name
self.id = id self.id: Optional[int] = id
self._state: Optional[ConnectionState] = None self._state: Optional[ConnectionState] = None
@classmethod @classmethod
def from_dict(cls, data: Union[PartialEmojiPayload, Dict[str, Any]]) -> Self: def from_dict(cls, data: Union[PartialEmojiPayload, ActivityEmoji, Dict[str, Any]]) -> Self:
return cls( return cls(
animated=data.get('animated', False), animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'), id=utils._get_as_snowflake(data, 'id'),
@ -178,10 +179,10 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>' return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>' return f'<:{self.name}:{self.id}>'
def __repr__(self): def __repr__(self) -> str:
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if self.is_unicode_emoji(): if self.is_unicode_emoji():
return isinstance(other, PartialEmoji) and self.name == other.name return isinstance(other, PartialEmoji) and self.name == other.name
@ -189,7 +190,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return self.id == other.id return self.id == other.id
return False return False
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:

19
discord/permissions.py

@ -234,12 +234,19 @@ class Permissions(BaseFlags):
@classmethod @classmethod
def stage_moderator(cls) -> Self: def stage_moderator(cls) -> Self:
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all permissions
"Stage Moderator" permissions from the official Discord UI set to ``True``. for stage moderators set to ``True``. These permissions are currently:
- :attr:`manage_channels`
- :attr:`mute_members`
- :attr:`move_members`
.. versionadded:: 1.7 .. versionadded:: 1.7
.. versionchanged:: 2.0
Added :attr:`manage_channels` permission and removed :attr:`request_to_speak` permission.
""" """
return cls(0b100000001010000000000000000000000) return cls(0b1010000000000000000010000)
@classmethod @classmethod
def advanced(cls) -> Self: def advanced(cls) -> Self:
@ -279,7 +286,7 @@ class Permissions(BaseFlags):
# So 0000 OP2 0101 -> 0101 # So 0000 OP2 0101 -> 0101
# The OP is base & ~denied. # The OP is base & ~denied.
# The OP2 is base | allowed. # The OP2 is base | allowed.
self.value = (self.value & ~deny) | allow self.value: int = (self.value & ~deny) | allow
@flag_value @flag_value
def create_instant_invite(self) -> int: def create_instant_invite(self) -> int:
@ -697,7 +704,7 @@ class PermissionOverwrite:
setattr(self, key, value) setattr(self, key, value)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, PermissionOverwrite) and self._values == other._values return isinstance(other, PermissionOverwrite) and self._values == other._values
def _set(self, key: str, value: Optional[bool]) -> None: def _set(self, key: str, value: Optional[bool]) -> None:
@ -750,7 +757,7 @@ class PermissionOverwrite:
""" """
return len(self._values) == 0 return len(self._values) == 0
def update(self, **kwargs: bool) -> None: def update(self, **kwargs: Optional[bool]) -> None:
r"""Bulk updates this permission overwrite object. r"""Bulk updates this permission overwrite object.
Allows you to set multiple attributes by using keyword Allows you to set multiple attributes by using keyword

31
discord/player.py

@ -365,12 +365,11 @@ class FFmpegOpusAudio(FFmpegAudio):
bitrate: Optional[int] = None, bitrate: Optional[int] = None,
codec: Optional[str] = None, codec: Optional[str] = None,
executable: str = 'ffmpeg', executable: str = 'ffmpeg',
pipe=False, pipe: bool = False,
stderr=None, stderr: Optional[IO[bytes]] = None,
before_options=None, before_options: Optional[str] = None,
options=None, options: Optional[str] = None,
) -> None: ) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
@ -521,9 +520,9 @@ class FFmpegOpusAudio(FFmpegAudio):
raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'") raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'")
codec = bitrate = None codec = bitrate = None
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
try: try:
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable))
except Exception: except Exception:
if not fallback: if not fallback:
_log.exception("Probe '%s' using '%s' failed", method, executable) _log.exception("Probe '%s' using '%s' failed", method, executable)
@ -531,7 +530,7 @@ class FFmpegOpusAudio(FFmpegAudio):
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try: try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable))
except Exception: except Exception:
_log.exception("Fallback probe using '%s' failed", executable) _log.exception("Fallback probe using '%s' failed", executable)
else: else:
@ -635,7 +634,13 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
class AudioPlayer(threading.Thread): class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
def __init__(self, source: AudioSource, client: Player, *, after=None): def __init__(
self,
source: AudioSource,
client: Player,
*,
after: Optional[Callable[[Optional[Exception]], Any]] = None,
) -> None:
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.daemon: bool = True self.daemon: bool = True
self.source: AudioSource = source self.source: AudioSource = source
@ -644,7 +649,7 @@ class AudioPlayer(threading.Thread):
self._end: threading.Event = threading.Event() self._end: threading.Event = threading.Event()
self._resumed: threading.Event = threading.Event() self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused self._resumed.set() # We are not paused
self._current_error: Optional[Exception] = None self._current_error: Optional[Exception] = None
self._connected: threading.Event = client.client._connected self._connected: threading.Event = client.client._connected
self._lock: threading.Lock = threading.Lock() self._lock: threading.Lock = threading.Lock()
@ -724,8 +729,8 @@ class AudioPlayer(threading.Thread):
self._speak(SpeakingState.none) self._speak(SpeakingState.none)
def resume(self, *, update_speaking: bool = True) -> None: def resume(self, *, update_speaking: bool = True) -> None:
self.loops = 0 self.loops: int = 0
self._start = time.perf_counter() self._start: float = time.perf_counter()
self._resumed.set() self._resumed.set()
if update_speaking: if update_speaking:
self._speak(SpeakingState.voice) self._speak(SpeakingState.voice)
@ -744,6 +749,6 @@ class AudioPlayer(threading.Thread):
def _speak(self, speaking: SpeakingState) -> None: def _speak(self, speaking: SpeakingState) -> None:
try: try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop)
except Exception as e: except Exception as e:
_log.info("Speaking call in player failed: %s", e) _log.info("Speaking call in player failed: %s", e)

10
discord/reaction.py

@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Any, TYPE_CHECKING, AsyncIterator, Union, Optional from typing import Any, TYPE_CHECKING, AsyncIterator, Union, Optional
from .user import User
from .object import Object from .object import Object
# fmt: off # fmt: off
@ -34,7 +35,6 @@ __all__ = (
# fmt: on # fmt: on
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import User
from .member import Member from .member import Member
from .types.message import Reaction as ReactionPayload from .types.message import Reaction as ReactionPayload
from .message import Message from .message import Message
@ -94,10 +94,10 @@ class Reaction:
""":class:`bool`: If this is a custom emoji.""" """:class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str) return not isinstance(self.emoji, str)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and other.emoji == self.emoji return isinstance(other, self.__class__) and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return other.emoji != self.emoji return other.emoji != self.emoji
return True return True
@ -149,8 +149,8 @@ class Reaction:
.. versionadded:: 1.3 .. versionadded:: 1.3
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Raises Raises
-------- --------

8
discord/role.py

@ -212,7 +212,7 @@ class Role(Hashable):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>' return f'<Role id={self.id} name={self.name!r}>'
def __lt__(self, other: Any) -> bool: def __lt__(self, other: object) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role): if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented return NotImplemented
@ -242,7 +242,7 @@ class Role(Hashable):
def __gt__(self, other: Any) -> bool: def __gt__(self, other: Any) -> bool:
return Role.__lt__(other, self) return Role.__lt__(other, self)
def __ge__(self, other: Any) -> bool: def __ge__(self, other: object) -> bool:
r = Role.__lt__(self, other) r = Role.__lt__(self, other)
if r is NotImplemented: if r is NotImplemented:
return NotImplemented return NotImplemented
@ -416,8 +416,8 @@ class Role(Hashable):
The ``display_icon``, ``icon``, and ``unicode_emoji`` keyword-only parameters were added. The ``display_icon``, ``icon``, and ``unicode_emoji`` keyword-only parameters were added.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------

66
discord/scheduled_event.py

@ -41,6 +41,7 @@ if TYPE_CHECKING:
) )
from .abc import Snowflake from .abc import Snowflake
from .guild import Guild
from .channel import VoiceChannel, StageChannel from .channel import VoiceChannel, StageChannel
from .state import ConnectionState from .state import ConnectionState
from .user import User from .user import User
@ -79,15 +80,15 @@ class ScheduledEvent(Hashable):
The scheduled event's ID. The scheduled event's ID.
name: :class:`str` name: :class:`str`
The name of the scheduled event. The name of the scheduled event.
description: :class:`str` description: Optional[:class:`str`]
The description of the scheduled event. The description of the scheduled event.
entity_type: :class:`EntityType` entity_type: :class:`EntityType`
The type of entity this event is for. The type of entity this event is for.
entity_id: :class:`int` entity_id: Optional[:class:`int`]
The ID of the entity this event is for. The ID of the entity this event is for if available.
start_time: :class:`datetime.datetime` start_time: :class:`datetime.datetime`
The time that the scheduled event will start in UTC. The time that the scheduled event will start in UTC.
end_time: :class:`datetime.datetime` end_time: Optional[:class:`datetime.datetime`]
The time that the scheduled event will end in UTC. The time that the scheduled event will end in UTC.
privacy_level: :class:`PrivacyLevel` privacy_level: :class:`PrivacyLevel`
The privacy level of the scheduled event. The privacy level of the scheduled event.
@ -130,9 +131,9 @@ class ScheduledEvent(Hashable):
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.guild_id: int = int(data['guild_id']) self.guild_id: int = int(data['guild_id'])
self.name: str = data['name'] self.name: str = data['name']
self.description: str = data.get('description', '') self.description: Optional[str] = data.get('description')
self.entity_type = try_enum(EntityType, data['entity_type']) self.entity_type: EntityType = try_enum(EntityType, data['entity_type'])
self.entity_id: int = int(data['id']) self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id')
self.start_time: datetime = parse_time(data['scheduled_start_time']) self.start_time: datetime = parse_time(data['scheduled_start_time'])
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status']) self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status'])
self.status: EventStatus = try_enum(EventStatus, data['status']) self.status: EventStatus = try_enum(EventStatus, data['status'])
@ -145,15 +146,14 @@ class ScheduledEvent(Hashable):
self.end_time: Optional[datetime] = parse_time(data.get('scheduled_end_time')) self.end_time: Optional[datetime] = parse_time(data.get('scheduled_end_time'))
self.channel_id: Optional[int] = _get_as_snowflake(data, 'channel_id') self.channel_id: Optional[int] = _get_as_snowflake(data, 'channel_id')
metadata = data.get('metadata') metadata = data.get('entity_metadata')
if metadata:
self._unroll_metadata(metadata) self._unroll_metadata(metadata)
def _unroll_metadata(self, data: EntityMetadata): def _unroll_metadata(self, data: Optional[EntityMetadata]):
self.location: Optional[str] = data.get('location') self.location: Optional[str] = data.get('location') if data else None
@classmethod @classmethod
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload): def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload) -> None:
creator_id = data.get('creator_id') creator_id = data.get('creator_id')
self = cls(state=state, data=data) self = cls(state=state, data=data)
if creator_id: if creator_id:
@ -169,11 +169,21 @@ class ScheduledEvent(Hashable):
return None return None
return Asset._from_scheduled_event_cover_image(self._state, self.id, self._cover_image) return Asset._from_scheduled_event_cover_image(self._state, self.id, self._cover_image)
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild this scheduled event is in."""
return self._state._get_guild(self.guild_id)
@property @property
def channel(self) -> Optional[Union[VoiceChannel, StageChannel]]: def channel(self) -> Optional[Union[VoiceChannel, StageChannel]]:
"""Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]: The channel this scheduled event is in.""" """Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]: The channel this scheduled event is in."""
return self.guild.get_channel(self.channel_id) # type: ignore return self.guild.get_channel(self.channel_id) # type: ignore
@property
def url(self) -> str:
""":class:`str`: The url for the scheduled event."""
return f'https://discord.com/events/{self.guild_id}/{self.id}'
async def start(self, *, reason: Optional[str] = None) -> ScheduledEvent: async def start(self, *, reason: Optional[str] = None) -> ScheduledEvent:
"""|coro| """|coro|
@ -286,7 +296,7 @@ class ScheduledEvent(Hashable):
description: str = MISSING, description: str = MISSING,
channel: Optional[Snowflake] = MISSING, channel: Optional[Snowflake] = MISSING,
start_time: datetime = MISSING, start_time: datetime = MISSING,
end_time: datetime = MISSING, end_time: Optional[datetime] = MISSING,
privacy_level: PrivacyLevel = MISSING, privacy_level: PrivacyLevel = MISSING,
entity_type: EntityType = MISSING, entity_type: EntityType = MISSING,
status: EventStatus = MISSING, status: EventStatus = MISSING,
@ -314,10 +324,14 @@ class ScheduledEvent(Hashable):
start_time: :class:`datetime.datetime` start_time: :class:`datetime.datetime`
The time that the scheduled event will start. This must be a timezone-aware The time that the scheduled event will start. This must be a timezone-aware
datetime object. Consider using :func:`utils.utcnow`. datetime object. Consider using :func:`utils.utcnow`.
end_time: :class:`datetime.datetime` end_time: Optional[:class:`datetime.datetime`]
The time that the scheduled event will end. This must be a timezone-aware The time that the scheduled event will end. This must be a timezone-aware
datetime object. Consider using :func:`utils.utcnow`. datetime object. Consider using :func:`utils.utcnow`.
If the entity type is either :attr:`EntityType.voice` or
:attr:`EntityType.stage_instance`, the end_time can be cleared by
passing ``None``.
Required if the entity type is :attr:`EntityType.external`. Required if the entity type is :attr:`EntityType.external`.
privacy_level: :class:`PrivacyLevel` privacy_level: :class:`PrivacyLevel`
The privacy level of the scheduled event. The privacy level of the scheduled event.
@ -325,8 +339,8 @@ class ScheduledEvent(Hashable):
The new entity type. The new entity type.
status: :class:`EventStatus` status: :class:`EventStatus`
The new status of the scheduled event. The new status of the scheduled event.
image: :class:`bytes` image: Optional[:class:`bytes`]
The new image of the scheduled event. The new image of the scheduled event or ``None`` to remove the image.
location: :class:`str` location: :class:`str`
The new location of the scheduled event. The new location of the scheduled event.
@ -383,7 +397,7 @@ class ScheduledEvent(Hashable):
payload['status'] = status.value payload['status'] = status.value
if image is not MISSING: if image is not MISSING:
image_as_str: str = _bytes_to_base64_data(image) image_as_str: Optional[str] = _bytes_to_base64_data(image) if image is not None else image
payload['image'] = image_as_str payload['image'] = image_as_str
if entity_type is not MISSING: if entity_type is not MISSING:
@ -400,25 +414,31 @@ class ScheduledEvent(Hashable):
payload['channel_id'] = channel.id payload['channel_id'] = channel.id
if location is not MISSING: if location not in (MISSING, None):
raise TypeError('location cannot be set when entity_type is voice or stage_instance') raise TypeError('location cannot be set when entity_type is voice or stage_instance')
payload['entity_metadata'] = None
else: else:
if channel is not MISSING: if channel not in (MISSING, None):
raise TypeError('channel cannot be set when entity_type is external') raise TypeError('channel cannot be set when entity_type is external')
payload['channel_id'] = None
if location is MISSING or location is None: if location is MISSING or location is None:
raise TypeError('location must be set when entity_type is external') raise TypeError('location must be set when entity_type is external')
metadata['location'] = location metadata['location'] = location
if end_time is MISSING: if end_time is MISSING or end_time is None:
raise TypeError('end_time must be set when entity_type is external') raise TypeError('end_time must be set when entity_type is external')
if end_time is not MISSING:
if end_time is not None:
if end_time.tzinfo is None: if end_time.tzinfo is None:
raise ValueError( raise ValueError(
'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.' 'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.'
) )
payload['scheduled_end_time'] = end_time.isoformat() payload['scheduled_end_time'] = end_time.isoformat()
else:
payload['scheduled_end_time'] = end_time
if metadata: if metadata:
payload['entity_metadata'] = metadata payload['entity_metadata'] = metadata
@ -459,7 +479,7 @@ class ScheduledEvent(Hashable):
) -> AsyncIterator[User]: ) -> AsyncIterator[User]:
"""|coro| """|coro|
Retrieves all :class:`User` that are in this thread. Retrieves all :class:`User` that are subscribed to this event.
This requires :attr:`Intents.members` to get information about members This requires :attr:`Intents.members` to get information about members
other than yourself. other than yourself.
@ -472,7 +492,7 @@ class ScheduledEvent(Hashable):
Returns Returns
-------- --------
List[:class:`User`] List[:class:`User`]
All thread members in the thread. All subscribed users of this event.
""" """
async def _before_strategy(retrieve, before, limit): async def _before_strategy(retrieve, before, limit):
@ -548,4 +568,4 @@ class ScheduledEvent(Hashable):
self._users[user.id] = user self._users[user.id] = user
def _pop_user(self, user_id: int) -> None: def _pop_user(self, user_id: int) -> None:
self._users.pop(user_id) self._users.pop(user_id, None)

22
discord/stage_instance.py

@ -26,7 +26,7 @@ from __future__ import annotations
from typing import Optional, TYPE_CHECKING from typing import Optional, TYPE_CHECKING
from .utils import MISSING, cached_slot_property from .utils import MISSING, cached_slot_property, _get_as_snowflake
from .mixins import Hashable from .mixins import Hashable
from .enums import PrivacyLevel, try_enum from .enums import PrivacyLevel, try_enum
@ -41,6 +41,7 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from .channel import StageChannel from .channel import StageChannel
from .guild import Guild from .guild import Guild
from .scheduled_event import ScheduledEvent
class StageInstance(Hashable): class StageInstance(Hashable):
@ -76,6 +77,10 @@ class StageInstance(Hashable):
The privacy level of the stage instance. The privacy level of the stage instance.
discoverable_disabled: :class:`bool` discoverable_disabled: :class:`bool`
Whether discoverability for the stage instance is disabled. Whether discoverability for the stage instance is disabled.
scheduled_event_id: Optional[:class:`int`]
The ID of scheduled event that belongs to the stage instance if any.
.. versionadded:: 2.0
""" """
__slots__ = ( __slots__ = (
@ -86,20 +91,23 @@ class StageInstance(Hashable):
'topic', 'topic',
'privacy_level', 'privacy_level',
'discoverable_disabled', 'discoverable_disabled',
'scheduled_event_id',
'_cs_channel', '_cs_channel',
'_cs_scheduled_event',
) )
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
self._state = state self._state: ConnectionState = state
self.guild = guild self.guild: Guild = guild
self._update(data) self._update(data)
def _update(self, data: StageInstancePayload): def _update(self, data: StageInstancePayload) -> None:
self.id: int = int(data['id']) self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic'] self.topic: str = data['topic']
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['privacy_level']) self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['privacy_level'])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False) self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
self.scheduled_event_id: Optional[int] = _get_as_snowflake(data, 'guild_scheduled_event_id')
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>' return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
@ -115,6 +123,12 @@ class StageInstance(Hashable):
# The returned channel will always be a StageChannel or None # The returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore return self._state.get_channel(self.channel_id) # type: ignore
@cached_slot_property('_cs_scheduled_event')
def scheduled_event(self) -> Optional[ScheduledEvent]:
"""Optional[:class:`ScheduledEvent`]: The scheduled event that belongs to the stage instance."""
# Guild.get_scheduled_event() expects an int, we are passing Optional[int]
return self.guild.get_scheduled_event(self.scheduled_event_id) # type: ignore
async def edit( async def edit(
self, self,
*, *,

74
discord/state.py

@ -41,6 +41,8 @@ from typing import (
Coroutine, Coroutine,
Tuple, Tuple,
Deque, Deque,
Literal,
overload,
) )
import weakref import weakref
import inspect import inspect
@ -93,7 +95,7 @@ if TYPE_CHECKING:
from .types.activity import Activity as ActivityPayload from .types.activity import Activity as ActivityPayload
from .types.channel import DMChannel as DMChannelPayload from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload, PartialUser as PartialUserPayload from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
@ -376,24 +378,24 @@ class ConnectionState:
def __init__( def __init__(
self, self,
*, *,
dispatch: Callable, dispatch: Callable[..., Any],
handlers: Dict[str, Callable], handlers: Dict[str, Callable[..., Any]],
hooks: Dict[str, Callable], hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]],
http: HTTPClient, http: HTTPClient,
loop: asyncio.AbstractEventLoop,
client: Client, client: Client,
**options: Any, **options: Any,
) -> None: ) -> None:
self.loop: asyncio.AbstractEventLoop = loop # Set later, after Client.login
self.loop: asyncio.AbstractEventLoop = utils.MISSING
self.http: HTTPClient = http self.http: HTTPClient = http
self.client = client self.client = client
self.max_messages: Optional[int] = options.get('max_messages', 1000) self.max_messages: Optional[int] = options.get('max_messages', 1000)
if self.max_messages is not None and self.max_messages <= 0: if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000 self.max_messages = 1000
self.dispatch: Callable = dispatch self.dispatch: Callable[..., Any] = dispatch
self.handlers: Dict[str, Callable] = handlers self.handlers: Dict[str, Callable[..., Any]] = handlers
self.hooks: Dict[str, Callable] = hooks self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks
self._ready_task: Optional[asyncio.Task] = None self._ready_task: Optional[asyncio.Task] = None
self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0) self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0)
@ -439,11 +441,11 @@ class ConnectionState:
if cache_flags._empty: if cache_flags._empty:
self.store_user = self.create_user # type: ignore self.store_user = self.create_user # type: ignore
parsers = {} self.parsers: Dict[str, Callable[[Any], None]]
self.parsers = parsers = {}
for attr, func in inspect.getmembers(self): for attr, func in inspect.getmembers(self):
if attr.startswith('parse_'): if attr.startswith('parse_'):
parsers[attr[6:].upper()] = func parsers[attr[6:].upper()] = func
self.parsers: Dict[str, Callable[[Dict[str, Any]], None]] = parsers
self.clear() self.clear()
@ -505,6 +507,9 @@ class ConnectionState:
else: else:
await coro(*args, **kwargs) await coro(*args, **kwargs)
async def async_setup(self) -> None:
pass
@property @property
def session_id(self) -> Optional[str]: def session_id(self) -> Optional[str]:
return self.ws.session_id return self.ws.session_id
@ -588,7 +593,7 @@ class ConnectionState:
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User: def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data) return User(state=self, data=data)
def get_user(self, id): def get_user(self, id: int) -> Optional[User]:
return self._users.get(id) return self._users.get(id)
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
@ -1049,7 +1054,7 @@ class ConnectionState:
if old_member._client_status != member._client_status or old_member._activities != member._activities: if old_member._client_status != member._client_status or old_member._activities != member._activities:
self.dispatch('presence_update', old_member, member) self.dispatch('presence_update', old_member, member)
def parse_user_update(self, data: gw.UserUpdateEvent): def parse_user_update(self, data: gw.UserUpdateEvent) -> None:
if self.user: if self.user:
self.user._update(data) self.user._update(data)
@ -1260,6 +1265,8 @@ class ConnectionState:
existing = guild.get_thread(int(data['id'])) existing = guild.get_thread(int(data['id']))
if existing is not None: if existing is not None:
old = existing._update(data) old = existing._update(data)
if existing.archived:
guild._remove_thread(existing)
if old is not None: if old is not None:
self.dispatch('thread_update', old, existing) self.dispatch('thread_update', old, existing)
else: # Shouldn't happen else: # Shouldn't happen
@ -1397,10 +1404,8 @@ class ConnectionState:
def parse_guild_member_remove(self, data: gw.GuildMemberRemoveEvent) -> None: def parse_guild_member_remove(self, data: gw.GuildMemberRemoveEvent) -> None:
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
if guild is not None: if guild is not None:
try: if guild._member_count is not None:
guild._member_count -= 1 guild._member_count -= 1
except AttributeError:
pass
user_id = int(data['user']['id']) user_id = int(data['user']['id'])
member = guild.get_member(user_id) member = guild.get_member(user_id)
@ -1630,7 +1635,7 @@ class ConnectionState:
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data): def _get_create_guild(self, data: gw.GuildCreateEvent):
guild = self._get_guild(int(data['id'])) guild = self._get_guild(int(data['id']))
# Discord being Discord sends a GUILD_CREATE after an OPCode 14 is sent (a la bots) # Discord being Discord sends a GUILD_CREATE after an OPCode 14 is sent (a la bots)
# However, we want that if we forced a GUILD_CREATE for an unavailable guild # However, we want that if we forced a GUILD_CREATE for an unavailable guild
@ -1640,7 +1645,7 @@ class ConnectionState:
return self._add_guild_from_data(data) return self._add_guild_from_data(data)
def is_guild_evicted(self, guild) -> bool: def is_guild_evicted(self, guild: Guild) -> bool:
return guild.id not in self._guilds return guild.id not in self._guilds
async def assert_guild_presence_count(self, guild: Guild): async def assert_guild_presence_count(self, guild: Guild):
@ -1706,9 +1711,15 @@ class ConnectionState:
) )
request.start() request.start()
if wait: @overload
return await request.wait() async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., channels: List[abcSnowflake] = ...) -> Optional[List[Member]]:
return request.get_future() ...
@overload
async def chunk_guild(
self, guild: Guild, *, wait: Literal[False] = ..., channels: List[abcSnowflake] = ...
) -> asyncio.Future[Optional[List[Member]]]:
...
async def chunk_guild( async def chunk_guild(
self, self,
@ -1716,7 +1727,7 @@ class ConnectionState:
*, *,
wait: bool = True, wait: bool = True,
channels: List[abcSnowflake] = MISSING, channels: List[abcSnowflake] = MISSING,
): ) -> Union[asyncio.Future[Optional[List[Member]]], Optional[List[Member]]]:
if not guild.me: if not guild.me:
await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore - self_id is always present here await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore - self_id is always present here
@ -1960,7 +1971,7 @@ class ConnectionState:
if guild is not None: if guild is not None:
scheduled_event = ScheduledEvent(state=self, data=data) scheduled_event = ScheduledEvent(state=self, data=data)
guild._scheduled_events[scheduled_event.id] = scheduled_event guild._scheduled_events[scheduled_event.id] = scheduled_event
self.dispatch('scheduled_event_create', guild, scheduled_event) self.dispatch('scheduled_event_create', scheduled_event)
else: else:
_log.debug('SCHEDULED_EVENT_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) _log.debug('SCHEDULED_EVENT_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
@ -1971,7 +1982,7 @@ class ConnectionState:
if scheduled_event is not None: if scheduled_event is not None:
old_scheduled_event = copy.copy(scheduled_event) old_scheduled_event = copy.copy(scheduled_event)
scheduled_event._update(data) scheduled_event._update(data)
self.dispatch('scheduled_event_update', guild, old_scheduled_event, scheduled_event) self.dispatch('scheduled_event_update', old_scheduled_event, scheduled_event)
else: else:
_log.debug('SCHEDULED_EVENT_UPDATE referencing unknown scheduled event ID: %s. Discarding.', data['id']) _log.debug('SCHEDULED_EVENT_UPDATE referencing unknown scheduled event ID: %s. Discarding.', data['id'])
else: else:
@ -1985,7 +1996,7 @@ class ConnectionState:
except KeyError: except KeyError:
pass pass
else: else:
self.dispatch('scheduled_event_delete', guild, scheduled_event) self.dispatch('scheduled_event_delete', scheduled_event)
else: else:
_log.debug('SCHEDULED_EVENT_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) _log.debug('SCHEDULED_EVENT_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
@ -1997,10 +2008,9 @@ class ConnectionState:
user = self.get_user(int(data['user_id'])) user = self.get_user(int(data['user_id']))
if user is not None: if user is not None:
scheduled_event._add_user(user) scheduled_event._add_user(user)
self.dispatch('scheduled_event_user_add', guild, scheduled_event, user) self.dispatch('scheduled_event_user_add', scheduled_event, user)
else: else:
_log.debug('SCHEDULED_EVENT_USER_ADD referencing unknown user ID: %s. Discarding.', data['user_id']) _log.debug('SCHEDULED_EVENT_USER_ADD referencing unknown user ID: %s. Discarding.', data['user_id'])
self.dispatch('scheduled_event_user_add', guild, scheduled_event, user)
else: else:
_log.debug( _log.debug(
'SCHEDULED_EVENT_USER_ADD referencing unknown scheduled event ID: %s. Discarding.', 'SCHEDULED_EVENT_USER_ADD referencing unknown scheduled event ID: %s. Discarding.',
@ -2020,7 +2030,6 @@ class ConnectionState:
self.dispatch('scheduled_event_user_remove', scheduled_event, user) self.dispatch('scheduled_event_user_remove', scheduled_event, user)
else: else:
_log.debug('SCHEDULED_EVENT_USER_REMOVE referencing unknown user ID: %s. Discarding.', data['user_id']) _log.debug('SCHEDULED_EVENT_USER_REMOVE referencing unknown user ID: %s. Discarding.', data['user_id'])
self.dispatch('scheduled_event_user_remove', scheduled_event, user)
else: else:
_log.debug( _log.debug(
'SCHEDULED_EVENT_USER_REMOVE referencing unknown scheduled event ID: %s. Discarding.', 'SCHEDULED_EVENT_USER_REMOVE referencing unknown scheduled event ID: %s. Discarding.',
@ -2173,16 +2182,19 @@ class ConnectionState:
return channel.guild.get_member(user_id) return channel.guild.get_member(user_id)
return self.get_user(user_id) return self.get_user(user_id)
def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]:
emoji_id = utils._get_as_snowflake(data, 'id') emoji_id = utils._get_as_snowflake(data, 'id')
if not emoji_id: if not emoji_id:
return data['name'] # the name key will be a str
return data['name'] # type: ignore
try: try:
return self._emojis[emoji_id] return self._emojis[emoji_id]
except KeyError: except KeyError:
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) return PartialEmoji.with_state(
self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore
)
def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
emoji_id = emoji.id emoji_id = emoji.id

4
discord/template.py

@ -174,8 +174,8 @@ class Template:
The ``region`` parameter has been removed. The ``region`` parameter has been removed.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
---------- ----------

70
discord/threads.py

@ -43,6 +43,7 @@ __all__ = (
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import datetime from datetime import datetime
from typing_extensions import Self
from .types.threads import ( from .types.threads import (
Thread as ThreadPayload, Thread as ThreadPayload,
@ -143,13 +144,13 @@ class Thread(Messageable, Hashable):
'_created_at', '_created_at',
) )
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None:
self._state: ConnectionState = state self._state: ConnectionState = state
self.guild = guild self.guild: Guild = guild
self._members: Dict[int, ThreadMember] = {} self._members: Dict[int, ThreadMember] = {}
self._from_data(data) self._from_data(data)
async def _get_channel(self): async def _get_channel(self) -> Self:
return self return self
def __repr__(self) -> str: def __repr__(self) -> str:
@ -162,26 +163,25 @@ class Thread(Messageable, Hashable):
return self.name return self.name
def _from_data(self, data: ThreadPayload): def _from_data(self, data: ThreadPayload):
self.id = int(data['id']) self.id: int = int(data['id'])
self.parent_id = int(data['parent_id']) self.parent_id: int = int(data['parent_id'])
self.owner_id = int(data['owner_id']) self.owner_id: int = int(data['owner_id'])
self.name = data['name'] self.name: str = data['name']
self._type = try_enum(ChannelType, data['type']) self._type: ChannelType = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id') self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0) self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count'] self.message_count: int = data['message_count']
self.member_count = data['member_count'] self.member_count: int = data['member_count']
self._member_ids = data['member_ids_preview'] self._member_ids: List[Union[str, int]] = data['member_ids_preview']
self._unroll_metadata(data['thread_metadata']) self._unroll_metadata(data['thread_metadata'])
def _unroll_metadata(self, data: ThreadMetadata): def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived'] self.archived: bool = data['archived']
self.auto_archive_duration = data['auto_archive_duration'] self.auto_archive_duration: int = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp']) self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
self._created_at = parse_time(data.get('creation_timestamp')) self.locked: bool = data.get('locked', False)
self.locked = data.get('locked', False) self.invitable: bool = data.get('invitable', True)
self.invitable = data.get('invitable', True) self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
self._created_at = parse_time(data.get('create_timestamp'))
def _update(self, data): def _update(self, data):
old = copy.copy(self) old = copy.copy(self)
@ -249,8 +249,7 @@ class Thread(Messageable, Hashable):
def members(self) -> List[ThreadMember]: def members(self) -> List[ThreadMember]:
"""List[:class:`ThreadMember`]: A list of thread members in this thread. """List[:class:`ThreadMember`]: A list of thread members in this thread.
Initial members are not provided by Discord. You must call :func:`fetch_members` Initial members are not provided by Discord. You must call :func:`fetch_members`.
or have thread subscribing enabled.
""" """
return list(self._members.values()) return list(self._members.values())
@ -577,7 +576,7 @@ class Thread(Messageable, Hashable):
# The data payload will always be a Thread payload # The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
async def join(self): async def join(self) -> None:
"""|coro| """|coro|
Joins this thread. Joins this thread.
@ -594,7 +593,7 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.join_thread(self.id) await self._state.http.join_thread(self.id)
async def leave(self): async def leave(self) -> None:
"""|coro| """|coro|
Leaves this thread. Leaves this thread.
@ -606,7 +605,7 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.leave_thread(self.id) await self._state.http.leave_thread(self.id)
async def add_user(self, user: Snowflake, /): async def add_user(self, user: Snowflake, /) -> None:
"""|coro| """|coro|
Adds a user to this thread. Adds a user to this thread.
@ -629,7 +628,7 @@ class Thread(Messageable, Hashable):
""" """
await self._state.http.add_user_to_thread(self.id, user.id) await self._state.http.add_user_to_thread(self.id, user.id)
async def remove_user(self, user: Snowflake, /): async def remove_user(self, user: Snowflake, /) -> None:
"""|coro| """|coro|
Removes a user from this thread. Removes a user from this thread.
@ -680,7 +679,7 @@ class Thread(Messageable, Hashable):
return self.members # Includes correct self.me return self.members # Includes correct self.me
async def delete(self): async def delete(self) -> None:
"""|coro| """|coro|
Deletes this thread. Deletes this thread.
@ -772,23 +771,24 @@ class ThreadMember(Hashable):
'parent', 'parent',
) )
def __init__(self, parent: Thread, data: ThreadMemberPayload): def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None:
self.parent = parent self.parent: Thread = parent
self._state = parent._state self._state: ConnectionState = parent._state
self._from_data(data) self._from_data(data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>' return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload): def _from_data(self, data: ThreadMemberPayload) -> None:
state = self._state state = self._state
self.id: int
try: try:
self.id = int(data['user_id']) self.id = int(data['user_id'])
except KeyError: except KeyError:
assert state.self_id is not None self.id = state.self_id # type: ignore
self.id = state.self_id
self.thread_id: int
try: try:
self.thread_id = int(data['id']) self.thread_id = int(data['id'])
except KeyError: except KeyError:
@ -798,7 +798,7 @@ class ThreadMember(Hashable):
self.flags = data.get('flags') self.flags = data.get('flags')
if (mdata := data.get('member')) is not None: if (mdata := data.get('member')) is not None:
guild = self.parent.parent.guild # type: ignore guild = self.parent.guild
mdata['guild_id'] = guild.id mdata['guild_id'] = guild.id
self.id = user_id = int(data['user_id']) self.id = user_id = int(data['user_id'])
mdata['presence'] = data.get('presence') mdata['presence'] = data.get('presence')
@ -817,4 +817,4 @@ class ThreadMember(Hashable):
"""Optional[:class:`Member`]: The member this :class:`ThreadMember` represents. If the member """Optional[:class:`Member`]: The member this :class:`ThreadMember` represents. If the member
is not cached then this will be ``None``. is not cached then this will be ``None``.
""" """
return self.parent.parent.guild.get_member(self.id) # type: ignore return self.parent.guild.get_member(self.id)

1
discord/types/activity.py

@ -112,3 +112,4 @@ class Activity(_BaseActivity, total=False):
session_id: Optional[str] session_id: Optional[str]
instance: bool instance: bool
buttons: List[ActivityButton] buttons: List[ActivityButton]
sync_id: str

1
discord/types/channel.py

@ -156,3 +156,4 @@ class StageInstance(TypedDict):
topic: str topic: str
privacy_level: PrivacyLevel privacy_level: PrivacyLevel
discoverable_disabled: bool discoverable_disabled: bool
guild_scheduled_event_id: Optional[int]

242
discord/types/interactions.py

@ -0,0 +1,242 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, Union
from .channel import ChannelTypeWithoutThread, ThreadMetadata
from .threads import ThreadType
from .member import Member
from .message import Attachment
from .role import Role
from .snowflake import Snowflake
from .user import User
if TYPE_CHECKING:
from .message import Message
InteractionType = Literal[1, 2, 3, 4, 5]
class _BasePartialChannel(TypedDict):
id: Snowflake
name: str
permissions: str
class PartialChannel(_BasePartialChannel):
type: ChannelTypeWithoutThread
class PartialThread(_BasePartialChannel):
type: ThreadType
thread_metadata: ThreadMetadata
parent_id: Snowflake
class ResolvedData(TypedDict, total=False):
users: Dict[str, User]
members: Dict[str, Member]
roles: Dict[str, Role]
channels: Dict[str, Union[PartialChannel, PartialThread]]
messages: Dict[str, Message]
attachments: Dict[str, Attachment]
class _BaseApplicationCommandInteractionDataOption(TypedDict):
name: str
class _CommandGroupApplicationCommandInteractionDataOption(_BaseApplicationCommandInteractionDataOption):
type: Literal[1, 2]
options: List[ApplicationCommandInteractionDataOption]
class _BaseValueApplicationCommandInteractionDataOption(_BaseApplicationCommandInteractionDataOption, total=False):
focused: bool
class _StringValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[3]
value: str
class _IntegerValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[4]
value: int
class _BooleanValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[5]
value: bool
class _SnowflakeValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[6, 7, 8, 9, 11]
value: Snowflake
class _NumberValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[10]
value: float
_ValueApplicationCommandInteractionDataOption = Union[
_StringValueApplicationCommandInteractionDataOption,
_IntegerValueApplicationCommandInteractionDataOption,
_BooleanValueApplicationCommandInteractionDataOption,
_SnowflakeValueApplicationCommandInteractionDataOption,
_NumberValueApplicationCommandInteractionDataOption,
]
ApplicationCommandInteractionDataOption = Union[
_CommandGroupApplicationCommandInteractionDataOption,
_ValueApplicationCommandInteractionDataOption,
]
class _BaseApplicationCommandInteractionDataOptional(TypedDict, total=False):
resolved: ResolvedData
guild_id: Snowflake
class _BaseApplicationCommandInteractionData(_BaseApplicationCommandInteractionDataOptional):
id: Snowflake
name: str
class ChatInputApplicationCommandInteractionData(_BaseApplicationCommandInteractionData, total=False):
type: Literal[1]
options: List[ApplicationCommandInteractionDataOption]
class _BaseNonChatInputApplicationCommandInteractionData(_BaseApplicationCommandInteractionData):
target_id: Snowflake
class UserApplicationCommandInteractionData(_BaseNonChatInputApplicationCommandInteractionData):
type: Literal[2]
class MessageApplicationCommandInteractionData(_BaseNonChatInputApplicationCommandInteractionData):
type: Literal[3]
ApplicationCommandInteractionData = Union[
ChatInputApplicationCommandInteractionData,
UserApplicationCommandInteractionData,
MessageApplicationCommandInteractionData,
]
class _BaseMessageComponentInteractionData(TypedDict):
custom_id: str
class ButtonMessageComponentInteractionData(_BaseMessageComponentInteractionData):
component_type: Literal[2]
class SelectMessageComponentInteractionData(_BaseMessageComponentInteractionData):
component_type: Literal[3]
values: List[str]
MessageComponentInteractionData = Union[ButtonMessageComponentInteractionData, SelectMessageComponentInteractionData]
class ModalSubmitTextInputInteractionData(TypedDict):
type: Literal[4]
custom_id: str
value: str
ModalSubmitComponentItemInteractionData = ModalSubmitTextInputInteractionData
class ModalSubmitActionRowInteractionData(TypedDict):
type: Literal[1]
components: List[ModalSubmitComponentItemInteractionData]
ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData, ModalSubmitComponentItemInteractionData]
class ModalSubmitInteractionData(TypedDict):
custom_id: str
components: List[ModalSubmitActionRowInteractionData]
InteractionData = Union[
ApplicationCommandInteractionData,
MessageComponentInteractionData,
ModalSubmitInteractionData,
]
class _BaseInteractionOptional(TypedDict, total=False):
guild_id: Snowflake
channel_id: Snowflake
locale: str
guild_locale: str
class _BaseInteraction(_BaseInteractionOptional):
id: Snowflake
application_id: Snowflake
token: str
version: Literal[1]
class PingInteraction(_BaseInteraction):
type: Literal[1]
class ApplicationCommandInteraction(_BaseInteraction):
type: Literal[2, 4]
data: ApplicationCommandInteractionData
class MessageComponentInteraction(_BaseInteraction):
type: Literal[3]
data: MessageComponentInteractionData
class ModalSubmitInteraction(_BaseInteraction):
type: Literal[5]
data: ModalSubmitInteractionData
Interaction = Union[PingInteraction, ApplicationCommandInteraction, MessageComponentInteraction, ModalSubmitInteraction]
class MessageInteraction(TypedDict):
id: Snowflake
type: InteractionType
name: str
user: User

4
discord/types/scheduled_event.py

@ -35,7 +35,7 @@ EntityType = Literal[1, 2, 3]
class _BaseGuildScheduledEventOptional(TypedDict, total=False): class _BaseGuildScheduledEventOptional(TypedDict, total=False):
creator_id: Optional[Snowflake] creator_id: Optional[Snowflake]
description: str description: Optional[str]
creator: User creator: User
user_count: int user_count: int
image: Optional[str] image: Optional[str]
@ -75,7 +75,7 @@ class EntityMetadata(TypedDict):
class ExternalScheduledEvent(_BaseGuildScheduledEvent): class ExternalScheduledEvent(_BaseGuildScheduledEvent):
channel_id: Literal[None] channel_id: Literal[None]
entity_metadata: EntityMetadata entity_metadata: EntityMetadata
scheduled_end_time: Optional[str] scheduled_end_time: str
entity_type: Literal[3] entity_type: Literal[3]

14
discord/types/widget.py

@ -46,18 +46,20 @@ class WidgetMember(User, total=False):
suppress: bool suppress: bool
class _WidgetOptional(TypedDict, total=False): class Widget(TypedDict):
id: Snowflake
name: str
instant_invite: Optional[str]
channels: List[WidgetChannel] channels: List[WidgetChannel]
members: List[WidgetMember] members: List[WidgetMember]
presence_count: int presence_count: int
class Widget(_WidgetOptional): class WidgetSettings(TypedDict):
id: Snowflake enabled: bool
name: str channel_id: Optional[Snowflake]
instant_invite: str
class WidgetSettings(TypedDict): class EditWidgetSettings(TypedDict, total=False):
enabled: bool enabled: bool
channel_id: Optional[Snowflake] channel_id: Optional[Snowflake]

10
discord/user.py

@ -245,10 +245,10 @@ class BaseUser(_UserTag):
def __str__(self) -> str: def __str__(self) -> str:
return f'{self.name}#{self.discriminator}' return f'{self.name}#{self.discriminator}'
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -613,8 +613,8 @@ class ClientUser(BaseUser):
The edit is no longer in-place, instead the newly edited client user is returned. The edit is no longer in-place, instead the newly edited client user is returned.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
----------- -----------
@ -845,7 +845,7 @@ class User(BaseUser, discord.abc.Connectable, discord.abc.Messageable):
def _get_voice_state_pair(self) -> Tuple[int, int]: def _get_voice_state_pair(self) -> Tuple[int, int]:
return self._state.self_id, self.dm_channel.id return self._state.self_id, self.dm_channel.id
async def _get_channel(self): async def _get_channel(self) -> DMChannel:
ch = await self.create_dm() ch = await self.create_dm()
return ch return ch

53
discord/utils.py

@ -29,6 +29,7 @@ from typing import (
Any, Any,
AsyncIterable, AsyncIterable,
AsyncIterator, AsyncIterator,
Awaitable,
Callable, Callable,
Coroutine, Coroutine,
Dict, Dict,
@ -42,6 +43,7 @@ from typing import (
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Set,
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
@ -76,7 +78,7 @@ import yarl
from .enums import BrowserEnum from .enums import BrowserEnum
try: try:
import orjson import orjson # type: ignore
except ModuleNotFoundError: except ModuleNotFoundError:
HAS_ORJSON = False HAS_ORJSON = False
else: else:
@ -111,6 +113,9 @@ class _MissingSentinel:
def __bool__(self): def __bool__(self):
return False return False
def __hash__(self):
return 0
def __repr__(self): def __repr__(self):
return '...' return '...'
@ -137,7 +142,7 @@ if TYPE_CHECKING:
from aiohttp import ClientSession from aiohttp import ClientSession
from functools import cached_property as cached_property from functools import cached_property as cached_property
from typing_extensions import ParamSpec from typing_extensions import ParamSpec, Self
from .permissions import Permissions from .permissions import Permissions
from .abc import Messageable, Snowflake from .abc import Messageable, Snowflake
@ -151,8 +156,16 @@ if TYPE_CHECKING:
P = ParamSpec('P') P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, Coroutine[Any, Any, 'T']],
Callable[P, 'T'],
]
_SnowflakeListBase = array.array[int]
else: else:
cached_property = _cached_property cached_property = _cached_property
_SnowflakeListBase = array.array
T = TypeVar('T') T = TypeVar('T')
@ -194,7 +207,7 @@ class classproperty(Generic[T_co]):
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co: def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner) return self.fget(owner)
def __set__(self, instance, value) -> None: def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute') raise AttributeError('cannot set attribute')
@ -226,7 +239,7 @@ class SequenceProxy(Sequence[T_co]):
def __reversed__(self) -> Iterator[T_co]: def __reversed__(self) -> Iterator[T_co]:
return reversed(self.__proxied) return reversed(self.__proxied)
def index(self, value: Any, *args, **kwargs) -> int: def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
return self.__proxied.index(value, *args, **kwargs) return self.__proxied.index(value, *args, **kwargs)
def count(self, value: Any) -> int: def count(self, value: Any) -> int:
@ -255,10 +268,10 @@ def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
def copy_doc(original: Callable) -> Callable[[T], T]: def copy_doc(original: Callable) -> Callable[[T], T]:
def decorator(overriden: T) -> T: def decorator(overridden: T) -> T:
overriden.__doc__ = original.__doc__ overridden.__doc__ = original.__doc__
overriden.__signature__ = _signature(original) # type: ignore overridden.__signature__ = _signature(original) # type: ignore
return overriden return overridden
return decorator return decorator
@ -588,9 +601,13 @@ def _bytes_to_base64_data(data: bytes) -> str:
return fmt.format(mime=mime, data=b64) return fmt.format(mime=mime, data=b64)
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + '.')
if HAS_ORJSON: if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore def _to_json(obj: Any) -> str:
return orjson.dumps(obj).decode('utf-8') return orjson.dumps(obj).decode('utf-8')
_from_json = orjson.loads # type: ignore _from_json = orjson.loads # type: ignore
@ -614,15 +631,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
return float(reset_after) return float(reset_after)
async def maybe_coroutine(f, *args, **kwargs): async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
value = f(*args, **kwargs) value = f(*args, **kwargs)
if _isawaitable(value): if _isawaitable(value):
return await value return await value
else: else:
return value return value # type: ignore
async def async_all(gen, *, check=_isawaitable): async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool:
for elem in gen: for elem in gen:
if check(elem): if check(elem):
elem = await elem elem = await elem
@ -631,7 +648,7 @@ async def async_all(gen, *, check=_isawaitable):
return True return True
async def sane_wait_for(futures, *, timeout): async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]:
ensured = [asyncio.ensure_future(fut) for fut in futures] ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
@ -649,7 +666,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]:
continue continue
def compute_timedelta(dt: datetime.datetime): def compute_timedelta(dt: datetime.datetime) -> float:
if dt.tzinfo is None: if dt.tzinfo is None:
dt = dt.astimezone() dt = dt.astimezone()
now = datetime.datetime.now(datetime.timezone.utc) now = datetime.datetime.now(datetime.timezone.utc)
@ -698,7 +715,7 @@ def valid_icon_size(size: int) -> bool:
return not size & (size - 1) and 4096 >= size >= 16 return not size & (size - 1) and 4096 >= size >= 16
class SnowflakeList(array.array): class SnowflakeList(_SnowflakeListBase):
"""Internal data storage class to efficiently store a list of snowflakes. """Internal data storage class to efficiently store a list of snowflakes.
This should have the following characteristics: This should have the following characteristics:
@ -717,7 +734,7 @@ class SnowflakeList(array.array):
def __init__(self, data: Iterable[int], *, is_sorted: bool = False): def __init__(self, data: Iterable[int], *, is_sorted: bool = False):
... ...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self:
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None: def add(self, element: int) -> None:
@ -1022,7 +1039,7 @@ def evaluate_annotation(
cache: Dict[str, Any], cache: Dict[str, Any],
*, *,
implicit_str: bool = True, implicit_str: bool = True,
): ) -> Any:
if isinstance(tp, ForwardRef): if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__ tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals # ForwardRefs always evaluate their internals
@ -1092,7 +1109,7 @@ def is_inside_class(func: Callable[..., Any]) -> bool:
# denoting which class it belongs to. So, e.g. for A.foo the qualname # denoting which class it belongs to. So, e.g. for A.foo the qualname
# would be A.foo while a global foo() would just be foo. # would be A.foo while a global foo() would just be foo.
# #
# Unfortuately, for nested functions this breaks. So inside an outer # Unfortunately, for nested functions this breaks. So inside an outer
# function named outer, those two would end up having a qualname with # function named outer, those two would end up having a qualname with
# outer.<locals>.A.foo and outer.<locals>.foo # outer.<locals>.A.foo and outer.<locals>.foo

25
discord/voice_client.py

@ -74,6 +74,7 @@ has_nacl: bool
try: try:
import nacl.secret # type: ignore import nacl.secret # type: ignore
import nacl.utils # type: ignore
has_nacl = True has_nacl = True
except ImportError: except ImportError:
@ -373,14 +374,14 @@ class VoiceClient(VoiceProtocol):
The endpoint we are connecting to. The endpoint we are connecting to.
channel: :class:`abc.Connectable` channel: :class:`abc.Connectable`
The voice channel connected to. The voice channel connected to.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
""" """
channel: abc.Connectable channel: abc.Connectable
endpoint_ip: str endpoint_ip: str
voice_port: int voice_port: int
secret_key: List[int] ip: str
port: int
secret_key: Optional[str]
def __init__(self, client: Client, channel: abc.Connectable): def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl: if not has_nacl:
@ -414,7 +415,7 @@ class VoiceClient(VoiceProtocol):
self.idrcs: Dict[int, int] = {} self.idrcs: Dict[int, int] = {}
self.ssids: Dict[int, int] = {} self.ssids: Dict[int, int] = {}
warn_nacl = not has_nacl warn_nacl: bool = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = ( supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite', 'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305_suffix',
@ -443,8 +444,15 @@ class VoiceClient(VoiceProtocol):
# Connection related # Connection related
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
setattr(self, attr, 0)
else:
setattr(self, attr, val + value)
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id'] self.session_id: str = data['session_id']
channel_id = data['channel_id'] channel_id = data['channel_id']
if not self._handshaking or self._potentially_reconnecting: if not self._handshaking or self._potentially_reconnecting:
@ -484,11 +492,12 @@ class VoiceClient(VoiceProtocol):
self.endpoint, _, _ = endpoint.rpartition(':') self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'): if self.endpoint.startswith('wss://'):
self.endpoint = self.endpoint[6:] # Shouldn't ever be there... # Just in case, strip it off since we're going to add it later
self.endpoint: str = self.endpoint[6:]
self.endpoint_ip = MISSING self.endpoint_ip = MISSING
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False) self.socket.setblocking(False)
if not self._handshaking: if not self._handshaking:
@ -575,7 +584,7 @@ class VoiceClient(VoiceProtocol):
raise raise
if self._runner is MISSING: if self._runner is MISSING:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self) -> bool: async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early # Attempt to stop the player thread from playing early

337
discord/webhook/async_.py

@ -30,7 +30,7 @@ import json
import re import re
from urllib.parse import quote as urlquote from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Sequence, Tuple, Union, TypeVar, Type, overload
from contextvars import ContextVar from contextvars import ContextVar
import weakref import weakref
@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType
from ..user import BaseUser, User from ..user import BaseUser, User
from ..flags import MessageFlags from ..flags import MessageFlags
from ..asset import Asset from ..asset import Asset
from ..http import Route, handle_message_parameters, MultipartParameters from ..http import Route, handle_message_parameters, HTTPClient
from ..mixins import Hashable from ..mixins import Hashable
from ..channel import PartialMessageable from ..channel import PartialMessageable
from ..file import File from ..file import File
@ -58,23 +58,37 @@ __all__ = (
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..embeds import Embed from ..embeds import Embed
from ..mentions import AllowedMentions from ..mentions import AllowedMentions
from ..message import Attachment from ..message import Attachment
from ..state import ConnectionState from ..state import ConnectionState
from ..http import Response from ..http import Response
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
import datetime
from ..types.webhook import ( from ..types.webhook import (
Webhook as WebhookPayload, Webhook as WebhookPayload,
SourceGuild as SourceGuildPayload,
) )
from ..types.message import ( from ..types.message import (
Message as MessagePayload, Message as MessagePayload,
) )
from ..guild import Guild from ..types.user import (
from ..channel import TextChannel User as UserPayload,
from ..abc import Snowflake PartialUser as PartialUserPayload,
import datetime )
from ..types.channel import (
PartialChannel as PartialChannelPayload,
)
BE = TypeVar('BE', bound=BaseException)
_State = Union[ConnectionState, '_WebhookState']
MISSING = utils.MISSING MISSING: Any = utils.MISSING
class AsyncDeferredLock: class AsyncDeferredLock:
@ -82,14 +96,19 @@ class AsyncDeferredLock:
self.lock = lock self.lock = lock
self.delta: Optional[float] = None self.delta: Optional[float] = None
async def __aenter__(self): async def __aenter__(self) -> Self:
await self.lock.acquire() await self.lock.acquire()
return self return self
def delay_by(self, delta: float) -> None: def delay_by(self, delta: float) -> None:
self.delta = delta self.delta = delta
async def __aexit__(self, type, value, traceback): async def __aexit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta: if self.delta:
await asyncio.sleep(self.delta) await asyncio.sleep(self.delta)
self.lock.release() self.lock.release()
@ -106,7 +125,7 @@ class AsyncWebhookAdapter:
*, *,
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[Sequence[File]] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
@ -259,7 +278,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
wait: bool = False, wait: bool = False,
) -> Response[Optional[MessagePayload]]: ) -> Response[Optional[MessagePayload]]:
@ -276,6 +295,7 @@ class AsyncWebhookAdapter:
message_id: int, message_id: int,
*, *,
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[MessagePayload]: ) -> Response[MessagePayload]:
route = Route( route = Route(
'GET', 'GET',
@ -284,7 +304,8 @@ class AsyncWebhookAdapter:
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
) )
return self.request(route, session) params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def edit_webhook_message( def edit_webhook_message(
self, self,
@ -295,7 +316,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None,
) -> Response[Message]: ) -> Response[Message]:
route = Route( route = Route(
'PATCH', 'PATCH',
@ -304,7 +326,8 @@ class AsyncWebhookAdapter:
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
) )
return self.request(route, session, payload=payload, multipart=multipart, files=files) params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def delete_webhook_message( def delete_webhook_message(
self, self,
@ -313,6 +336,7 @@ class AsyncWebhookAdapter:
message_id: int, message_id: int,
*, *,
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[None]: ) -> Response[None]:
route = Route( route = Route(
'DELETE', 'DELETE',
@ -321,7 +345,8 @@ class AsyncWebhookAdapter:
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
) )
return self.request(route, session) params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def fetch_webhook( def fetch_webhook(
self, self,
@ -344,111 +369,6 @@ class AsyncWebhookAdapter:
return self.request(route, session=session) return self.request(route, session=session)
def interaction_response_params(type: int, data: Optional[Dict[str, Any]] = None) -> MultipartParameters:
payload: Dict[str, Any] = {
'type': type,
}
if data is not None:
payload['data'] = data
return MultipartParameters(payload=payload, multipart=None, files=None)
# This is a subset of handle_message_parameters
def interaction_message_response_params(
*,
type: int,
content: Optional[str] = MISSING,
tts: bool = False,
flags: MessageFlags = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = MISSING,
previous_allowed_mentions: Optional[AllowedMentions] = None,
) -> MultipartParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
if embeds is not MISSING and embed is not MISSING:
raise TypeError('Cannot mix embed and embeds keyword arguments.')
if file is not MISSING:
files = [file]
if attachments is not MISSING and files is not MISSING:
raise TypeError('Cannot mix attachments and files keyword arguments.')
data: Optional[Dict[str, Any]] = {
'tts': tts,
}
if embeds is not MISSING:
if len(embeds) > 10:
raise ValueError('embeds has a maximum of 10 elements.')
data['embeds'] = [e.to_dict() for e in embeds]
if embed is not MISSING:
if embed is None:
data['embeds'] = []
else:
data['embeds'] = [embed.to_dict()]
if content is not MISSING:
if content is not None:
data['content'] = str(content)
else:
data['content'] = None
if flags is not MISSING:
data['flags'] = flags.value
if allowed_mentions:
if previous_allowed_mentions is not None:
data['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
else:
data['allowed_mentions'] = allowed_mentions.to_dict()
elif previous_allowed_mentions is not None:
data['allowed_mentions'] = previous_allowed_mentions.to_dict()
if attachments is MISSING:
attachments = files # type: ignore
else:
files = [a for a in attachments if isinstance(a, File)]
if attachments is not MISSING:
file_index = 0
attachments_payload = []
for attachment in attachments:
if isinstance(attachment, File):
attachments_payload.append(attachment.to_dict(file_index))
file_index += 1
else:
attachments_payload.append(attachment.to_dict())
data['attachments'] = attachments_payload
multipart = []
if files:
data = {'type': type, 'data': data}
multipart.append({'name': 'payload_json', 'value': utils._to_json(data)})
data = None
for index, file in enumerate(files):
multipart.append(
{
'name': f'files[{index}]',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
}
)
else:
data = {'type': type, 'data': data}
return MultipartParameters(payload=data, multipart=multipart, files=files)
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter()) async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
@ -469,11 +389,11 @@ class PartialWebhookChannel(Hashable):
__slots__ = ('id', 'name') __slots__ = ('id', 'name')
def __init__(self, *, data): def __init__(self, *, data: PartialChannelPayload) -> None:
self.id = int(data['id']) self.id: int = int(data['id'])
self.name = data['name'] self.name: str = data['name']
def __repr__(self): def __repr__(self) -> str:
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>' return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
@ -494,13 +414,13 @@ class PartialWebhookGuild(Hashable):
__slots__ = ('id', 'name', '_icon', '_state') __slots__ = ('id', 'name', '_icon', '_state')
def __init__(self, *, data, state): def __init__(self, *, data: SourceGuildPayload, state: _State) -> None:
self._state = state self._state: _State = state
self.id = int(data['id']) self.id: int = int(data['id'])
self.name = data['name'] self.name: str = data['name']
self._icon = data['icon'] self._icon: str = data['icon']
def __repr__(self): def __repr__(self) -> str:
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>' return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
@property @property
@ -514,14 +434,14 @@ class PartialWebhookGuild(Hashable):
class _FriendlyHttpAttributeErrorHelper: class _FriendlyHttpAttributeErrorHelper:
__slots__ = () __slots__ = ()
def __getattr__(self, attr): def __getattr__(self, attr: str) -> Any:
raise AttributeError('PartialWebhookState does not support http methods.') raise AttributeError('PartialWebhookState does not support http methods.')
class _WebhookState: class _WebhookState:
__slots__ = ('_parent', '_webhook') __slots__ = ('_parent', '_webhook', '_thread')
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]): def __init__(self, webhook: Any, parent: Optional[_State], thread: Snowflake = MISSING):
self._webhook: Any = webhook self._webhook: Any = webhook
self._parent: Optional[ConnectionState] self._parent: Optional[ConnectionState]
@ -530,23 +450,25 @@ class _WebhookState:
else: else:
self._parent = parent self._parent = parent
def _get_guild(self, guild_id): self._thread: Snowflake = thread
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
if self._parent is not None: if self._parent is not None:
return self._parent._get_guild(guild_id) return self._parent._get_guild(guild_id)
return None return None
def store_user(self, data): def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
if self._parent is not None: if self._parent is not None:
return self._parent.store_user(data) return self._parent.store_user(data)
# state parameter is artificial # state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore return BaseUser(state=self, data=data) # type: ignore
def create_user(self, data): def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
# state parameter is artificial # state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore return BaseUser(state=self, data=data) # type: ignore
@property @property
def http(self): def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]:
if self._parent is not None: if self._parent is not None:
return self._parent.http return self._parent.http
@ -554,7 +476,7 @@ class _WebhookState:
# However, using it should result in a late-binding error # However, using it should result in a late-binding error
return _FriendlyHttpAttributeErrorHelper() return _FriendlyHttpAttributeErrorHelper()
def __getattr__(self, attr): def __getattr__(self, attr: str) -> Any:
if self._parent is not None: if self._parent is not None:
return getattr(self._parent, attr) return getattr(self._parent, attr)
@ -578,9 +500,9 @@ class WebhookMessage(Message):
async def edit( async def edit(
self, self,
content: Optional[str] = MISSING, content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING, embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING, attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None, allowed_mentions: Optional[AllowedMentions] = None,
) -> WebhookMessage: ) -> WebhookMessage:
"""|coro| """|coro|
@ -593,8 +515,8 @@ class WebhookMessage(Message):
The edit is no longer in-place, instead the newly edited message is returned. The edit is no longer in-place, instead the newly edited message is returned.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -642,6 +564,7 @@ class WebhookMessage(Message):
embed=embed, embed=embed,
attachments=attachments, attachments=attachments,
allowed_mentions=allowed_mentions, allowed_mentions=allowed_mentions,
thread=self._state._thread,
) )
async def add_files(self, *files: File) -> WebhookMessage: async def add_files(self, *files: File) -> WebhookMessage:
@ -722,13 +645,13 @@ class WebhookMessage(Message):
async def inner_call(delay: float = delay): async def inner_call(delay: float = delay):
await asyncio.sleep(delay) await asyncio.sleep(delay)
try: try:
await self._state._webhook.delete_message(self.id) await self._state._webhook.delete_message(self.id, thread=self._state._thread)
except HTTPException: except HTTPException:
pass pass
asyncio.create_task(inner_call()) asyncio.create_task(inner_call())
else: else:
await self._state._webhook.delete_message(self.id) await self._state._webhook.delete_message(self.id, thread=self._state._thread)
class BaseWebhook(Hashable): class BaseWebhook(Hashable):
@ -747,19 +670,24 @@ class BaseWebhook(Hashable):
'_state', '_state',
) )
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None): def __init__(
self,
data: WebhookPayload,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
self.auth_token: Optional[str] = token self.auth_token: Optional[str] = token
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state) self._state: _State = state or _WebhookState(self, parent=state)
self._update(data) self._update(data)
def _update(self, data: WebhookPayload): def _update(self, data: WebhookPayload) -> None:
self.id = int(data['id']) self.id: int = int(data['id'])
self.type = try_enum(WebhookType, int(data['type'])) self.type: WebhookType = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id') self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id') self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name') self.name: Optional[str] = data.get('name')
self._avatar = data.get('avatar') self._avatar: Optional[str] = data.get('avatar')
self.token = data.get('token') self.token: Optional[str] = data.get('token')
user = data.get('user') user = data.get('user')
self.user: Optional[Union[BaseUser, User]] = None self.user: Optional[Union[BaseUser, User]] = None
@ -927,11 +855,17 @@ class Webhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',) __slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None): def __init__(
self,
data: WebhookPayload,
session: aiohttp.ClientSession,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
super().__init__(data, token, state) super().__init__(data, token, state)
self.session = session self.session: aiohttp.ClientSession = session
def __repr__(self): def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>' return f'<Webhook id={self.id!r}>'
@property @property
@ -940,7 +874,7 @@ class Webhook(BaseWebhook):
return f'https://discord.com/api/webhooks/{self.id}/{self.token}' return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod @classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook`. """Creates a partial :class:`Webhook`.
Parameters Parameters
@ -976,12 +910,12 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) return cls(data, session, token=bot_token)
@classmethod @classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook` from a webhook URL. """Creates a partial :class:`Webhook` from a webhook URL.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -1019,7 +953,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) # type: ignore return cls(data, session, token=bot_token) # type: ignore
@classmethod @classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook: def _as_follower(cls, data, *, channel, user) -> Self:
name = f"{channel.guild} #{channel}" name = f"{channel.guild} #{channel}"
feed: WebhookPayload = { feed: WebhookPayload = {
'id': data['webhook_id'], 'id': data['webhook_id'],
@ -1035,8 +969,8 @@ class Webhook(BaseWebhook):
return cls(feed, session=session, state=state, token=state.http.token) return cls(feed, session=session, state=state, token=state.http.token)
@classmethod @classmethod
def from_state(cls, data, state) -> Webhook: def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self:
session = state.http._HTTPClient__session session = state.http._HTTPClient__session # type: ignore
return cls(data, session=session, state=state, token=state.http.token) return cls(data, session=session, state=state, token=state.http.token)
async def fetch(self, *, prefer_auth: bool = True) -> Webhook: async def fetch(self, *, prefer_auth: bool = True) -> Webhook:
@ -1085,7 +1019,7 @@ class Webhook(BaseWebhook):
return Webhook(data, self.session, token=self.auth_token, state=self._state) return Webhook(data, self.session, token=self.auth_token, state=self._state)
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True): async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
"""|coro| """|coro|
Deletes this Webhook. Deletes this Webhook.
@ -1137,8 +1071,8 @@ class Webhook(BaseWebhook):
Edits this Webhook. Edits this Webhook.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``~InvalidArgument``.
Parameters Parameters
------------ ------------
@ -1203,8 +1137,8 @@ class Webhook(BaseWebhook):
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state) return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data): def _create_message(self, data, *, thread: Snowflake):
state = _WebhookState(self, parent=self._state) state = _WebhookState(self, parent=self._state, thread=thread)
# state may be artificial (unlikely at this point...) # state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
# state is artificial # state is artificial
@ -1219,9 +1153,9 @@ class Webhook(BaseWebhook):
avatar_url: Any = MISSING, avatar_url: Any = MISSING,
tts: bool = MISSING, tts: bool = MISSING,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Embed = MISSING, embed: Embed = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING, allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING, thread: Snowflake = MISSING,
wait: Literal[True], wait: Literal[True],
@ -1238,9 +1172,9 @@ class Webhook(BaseWebhook):
avatar_url: Any = MISSING, avatar_url: Any = MISSING,
tts: bool = MISSING, tts: bool = MISSING,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Embed = MISSING, embed: Embed = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING, allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING, thread: Snowflake = MISSING,
wait: Literal[False] = ..., wait: Literal[False] = ...,
@ -1256,9 +1190,9 @@ class Webhook(BaseWebhook):
avatar_url: Any = MISSING, avatar_url: Any = MISSING,
tts: bool = False, tts: bool = False,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Embed = MISSING, embed: Embed = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING, allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING, thread: Snowflake = MISSING,
wait: bool = False, wait: bool = False,
@ -1278,8 +1212,8 @@ class Webhook(BaseWebhook):
``embeds`` parameter, which must be a :class:`list` of :class:`Embed` objects to send. ``embeds`` parameter, which must be a :class:`list` of :class:`Embed` objects to send.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -1389,11 +1323,11 @@ class Webhook(BaseWebhook):
msg = None msg = None
if wait: if wait:
msg = self._create_message(data) msg = self._create_message(data, thread=thread)
return msg return msg
async def fetch_message(self, id: int, /) -> WebhookMessage: async def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> WebhookMessage:
"""|coro| """|coro|
Retrieves a single :class:`~discord.WebhookMessage` owned by this webhook. Retrieves a single :class:`~discord.WebhookMessage` owned by this webhook.
@ -1404,6 +1338,8 @@ class Webhook(BaseWebhook):
------------ ------------
id: :class:`int` id: :class:`int`
The message ID to look for. The message ID to look for.
thread: :class:`~discord.abc.Snowflake`
The thread to look in.
Raises Raises
-------- --------
@ -1425,24 +1361,30 @@ class Webhook(BaseWebhook):
if self.token is None: if self.token is None:
raise ValueError('This webhook does not have a token associated with it') raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter = async_context.get() adapter = async_context.get()
data = await adapter.get_webhook_message( data = await adapter.get_webhook_message(
self.id, self.id,
self.token, self.token,
id, id,
session=self.session, session=self.session,
thread_id=thread_id,
) )
return self._create_message(data) return self._create_message(data, thread=thread)
async def edit_message( async def edit_message(
self, self,
message_id: int, message_id: int,
*, *,
content: Optional[str] = MISSING, content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING, embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING, attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None, allowed_mentions: Optional[AllowedMentions] = None,
thread: Snowflake = MISSING,
) -> WebhookMessage: ) -> WebhookMessage:
"""|coro| """|coro|
@ -1457,8 +1399,8 @@ class Webhook(BaseWebhook):
The edit is no longer in-place, instead the newly edited message is returned. The edit is no longer in-place, instead the newly edited message is returned.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -1479,6 +1421,10 @@ class Webhook(BaseWebhook):
allowed_mentions: :class:`AllowedMentions` allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message. Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information. See :meth:`.abc.Messageable.send` for more information.
thread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises Raises
------- -------
@ -1511,6 +1457,11 @@ class Webhook(BaseWebhook):
allowed_mentions=allowed_mentions, allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions, previous_allowed_mentions=previous_mentions,
) )
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter = async_context.get() adapter = async_context.get()
data = await adapter.edit_webhook_message( data = await adapter.edit_webhook_message(
self.id, self.id,
@ -1520,12 +1471,13 @@ class Webhook(BaseWebhook):
payload=params.payload, payload=params.payload,
multipart=params.multipart, multipart=params.multipart,
files=params.files, files=params.files,
thread_id=thread_id,
) )
message = self._create_message(data) message = self._create_message(data, thread=thread)
return message return message
async def delete_message(self, message_id: int, /) -> None: async def delete_message(self, message_id: int, /, *, thread: Snowflake = MISSING) -> None:
"""|coro| """|coro|
Deletes a message owned by this webhook. Deletes a message owned by this webhook.
@ -1540,13 +1492,17 @@ class Webhook(BaseWebhook):
``message_id`` parameter is now positional-only. ``message_id`` parameter is now positional-only.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`ValueError` instead of
:exc:`ValueError`. ``InvalidArgument``.
Parameters Parameters
------------ ------------
message_id: :class:`int` message_id: :class:`int`
The message ID to delete. The message ID to delete.
thread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises Raises
------- -------
@ -1560,10 +1516,15 @@ class Webhook(BaseWebhook):
if self.token is None: if self.token is None:
raise ValueError('This webhook does not have a token associated with it') raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter = async_context.get() adapter = async_context.get()
await adapter.delete_webhook_message( await adapter.delete_webhook_message(
self.id, self.id,
self.token, self.token,
message_id, message_id,
session=self.session, session=self.session,
thread_id=thread_id,
) )

148
discord/webhook/sync.py

@ -37,7 +37,7 @@ import time
import re import re
from urllib.parse import quote as urlquote from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Sequence, Tuple, Union, TypeVar, Type, overload
import weakref import weakref
from .. import utils from .. import utils
@ -56,36 +56,50 @@ __all__ = (
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..file import File from ..file import File
from ..embeds import Embed from ..embeds import Embed
from ..mentions import AllowedMentions from ..mentions import AllowedMentions
from ..message import Attachment from ..message import Attachment
from ..abc import Snowflake
from ..state import ConnectionState
from ..types.webhook import ( from ..types.webhook import (
Webhook as WebhookPayload, Webhook as WebhookPayload,
) )
from ..abc import Snowflake from ..types.message import (
Message as MessagePayload,
)
BE = TypeVar('BE', bound=BaseException)
try: try:
from requests import Session, Response from requests import Session, Response
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
MISSING = utils.MISSING MISSING: Any = utils.MISSING
class DeferredLock: class DeferredLock:
def __init__(self, lock: threading.Lock): def __init__(self, lock: threading.Lock) -> None:
self.lock = lock self.lock: threading.Lock = lock
self.delta: Optional[float] = None self.delta: Optional[float] = None
def __enter__(self): def __enter__(self) -> Self:
self.lock.acquire() self.lock.acquire()
return self return self
def delay_by(self, delta: float) -> None: def delay_by(self, delta: float) -> None:
self.delta = delta self.delta = delta
def __exit__(self, type, value, traceback): def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta: if self.delta:
time.sleep(self.delta) time.sleep(self.delta)
self.lock.release() self.lock.release()
@ -102,7 +116,7 @@ class WebhookAdapter:
*, *,
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[Sequence[File]] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
@ -218,7 +232,7 @@ class WebhookAdapter:
token: Optional[str] = None, token: Optional[str] = None,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token) return self.request(route, session, reason=reason, auth_token=token)
@ -229,7 +243,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason) return self.request(route, session, reason=reason)
@ -241,7 +255,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token) return self.request(route, session, reason=reason, payload=payload, auth_token=token)
@ -253,7 +267,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload) return self.request(route, session, reason=reason, payload=payload)
@ -265,10 +279,10 @@ class WebhookAdapter:
session: Session, session: Session,
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
wait: bool = False, wait: bool = False,
): ) -> MessagePayload:
params = {'wait': int(wait)} params = {'wait': int(wait)}
if thread_id: if thread_id:
params['thread_id'] = thread_id params['thread_id'] = thread_id
@ -282,7 +296,8 @@ class WebhookAdapter:
message_id: int, message_id: int,
*, *,
session: Session, session: Session,
): thread_id: Optional[int] = None,
) -> MessagePayload:
route = Route( route = Route(
'GET', 'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -290,7 +305,8 @@ class WebhookAdapter:
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
) )
return self.request(route, session) params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def edit_webhook_message( def edit_webhook_message(
self, self,
@ -301,8 +317,9 @@ class WebhookAdapter:
session: Session, session: Session,
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None, multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None, files: Optional[Sequence[File]] = None,
): thread_id: Optional[int] = None,
) -> MessagePayload:
route = Route( route = Route(
'PATCH', 'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -310,7 +327,8 @@ class WebhookAdapter:
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
) )
return self.request(route, session, payload=payload, multipart=multipart, files=files) params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def delete_webhook_message( def delete_webhook_message(
self, self,
@ -319,7 +337,8 @@ class WebhookAdapter:
message_id: int, message_id: int,
*, *,
session: Session, session: Session,
): thread_id: Optional[int] = None,
) -> None:
route = Route( route = Route(
'DELETE', 'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', '/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -327,7 +346,8 @@ class WebhookAdapter:
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
) )
return self.request(route, session) params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def fetch_webhook( def fetch_webhook(
self, self,
@ -335,7 +355,7 @@ class WebhookAdapter:
token: str, token: str,
*, *,
session: Session, session: Session,
): ) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token) return self.request(route, session=session, auth_token=token)
@ -345,7 +365,7 @@ class WebhookAdapter:
token: str, token: str,
*, *,
session: Session, session: Session,
): ) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session) return self.request(route, session=session)
@ -380,16 +400,16 @@ class SyncWebhookMessage(Message):
def edit( def edit(
self, self,
content: Optional[str] = MISSING, content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING, embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING, attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None, allowed_mentions: Optional[AllowedMentions] = None,
) -> SyncWebhookMessage: ) -> SyncWebhookMessage:
"""Edits the message. """Edits the message.
.. versionchanged:: 2.0 .. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising This function will now raise :exc:`TypeError` or
:exc:`ValueError` or :exc:`TypeError` in various cases. :exc:`ValueError` instead of ``InvalidArgument``.
Parameters Parameters
------------ ------------
@ -437,6 +457,7 @@ class SyncWebhookMessage(Message):
embed=embed, embed=embed,
attachments=attachments, attachments=attachments,
allowed_mentions=allowed_mentions, allowed_mentions=allowed_mentions,
thread=self._state._thread,
) )
def add_files(self, *files: File) -> SyncWebhookMessage: def add_files(self, *files: File) -> SyncWebhookMessage:
@ -508,7 +529,7 @@ class SyncWebhookMessage(Message):
if delay is not None: if delay is not None:
time.sleep(delay) time.sleep(delay)
self._state._webhook.delete_message(self.id) self._state._webhook.delete_message(self.id, thread=self._state._thread)
class SyncWebhook(BaseWebhook): class SyncWebhook(BaseWebhook):
@ -569,11 +590,17 @@ class SyncWebhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',) __slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None): def __init__(
self,
data: WebhookPayload,
session: Session,
token: Optional[str] = None,
state: Optional[Union[ConnectionState, _WebhookState]] = None,
) -> None:
super().__init__(data, token, state) super().__init__(data, token, state)
self.session = session self.session: Session = session
def __repr__(self): def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>' return f'<Webhook id={self.id!r}>'
@property @property
@ -812,8 +839,8 @@ class SyncWebhook(BaseWebhook):
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state) return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data): def _create_message(self, data: MessagePayload, *, thread: Snowflake = MISSING) -> SyncWebhookMessage:
state = _WebhookState(self, parent=self._state) state = _WebhookState(self, parent=self._state, thread=thread)
# state may be artificial (unlikely at this point...) # state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
# state is artificial # state is artificial
@ -828,10 +855,11 @@ class SyncWebhook(BaseWebhook):
avatar_url: Any = MISSING, avatar_url: Any = MISSING,
tts: bool = MISSING, tts: bool = MISSING,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Embed = MISSING, embed: Embed = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING, allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: Literal[True], wait: Literal[True],
suppress_embeds: bool = MISSING, suppress_embeds: bool = MISSING,
) -> SyncWebhookMessage: ) -> SyncWebhookMessage:
@ -846,10 +874,11 @@ class SyncWebhook(BaseWebhook):
avatar_url: Any = MISSING, avatar_url: Any = MISSING,
tts: bool = MISSING, tts: bool = MISSING,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Embed = MISSING, embed: Embed = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING, allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: Literal[False] = ..., wait: Literal[False] = ...,
suppress_embeds: bool = MISSING, suppress_embeds: bool = MISSING,
) -> None: ) -> None:
@ -863,9 +892,9 @@ class SyncWebhook(BaseWebhook):
avatar_url: Any = MISSING, avatar_url: Any = MISSING,
tts: bool = False, tts: bool = False,
file: File = MISSING, file: File = MISSING,
files: List[File] = MISSING, files: Sequence[File] = MISSING,
embed: Embed = MISSING, embed: Embed = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING, allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING, thread: Snowflake = MISSING,
wait: bool = False, wait: bool = False,
@ -984,9 +1013,9 @@ class SyncWebhook(BaseWebhook):
wait=wait, wait=wait,
) )
if wait: if wait:
return self._create_message(data) return self._create_message(data, thread=thread)
def fetch_message(self, id: int, /) -> SyncWebhookMessage: def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> SyncWebhookMessage:
"""Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook. """Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook.
.. versionadded:: 2.0 .. versionadded:: 2.0
@ -995,6 +1024,8 @@ class SyncWebhook(BaseWebhook):
------------ ------------
id: :class:`int` id: :class:`int`
The message ID to look for. The message ID to look for.
thread: :class:`~discord.abc.Snowflake`
The thread to look in.
Raises Raises
-------- --------
@ -1016,24 +1047,30 @@ class SyncWebhook(BaseWebhook):
if self.token is None: if self.token is None:
raise ValueError('This webhook does not have a token associated with it') raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.get_webhook_message( data = adapter.get_webhook_message(
self.id, self.id,
self.token, self.token,
id, id,
session=self.session, session=self.session,
thread_id=thread_id,
) )
return self._create_message(data) return self._create_message(data, thread=thread)
def edit_message( def edit_message(
self, self,
message_id: int, message_id: int,
*, *,
content: Optional[str] = MISSING, content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING, embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING, embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING, attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None, allowed_mentions: Optional[AllowedMentions] = None,
thread: Snowflake = MISSING,
) -> SyncWebhookMessage: ) -> SyncWebhookMessage:
"""Edits a message owned by this webhook. """Edits a message owned by this webhook.
@ -1061,6 +1098,10 @@ class SyncWebhook(BaseWebhook):
allowed_mentions: :class:`AllowedMentions` allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message. Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information. See :meth:`.abc.Messageable.send` for more information.
thread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises Raises
------- -------
@ -1087,6 +1128,11 @@ class SyncWebhook(BaseWebhook):
allowed_mentions=allowed_mentions, allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions, previous_allowed_mentions=previous_mentions,
) )
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.edit_webhook_message( data = adapter.edit_webhook_message(
self.id, self.id,
@ -1096,10 +1142,11 @@ class SyncWebhook(BaseWebhook):
payload=params.payload, payload=params.payload,
multipart=params.multipart, multipart=params.multipart,
files=params.files, files=params.files,
thread_id=thread_id,
) )
return self._create_message(data) return self._create_message(data, thread=thread)
def delete_message(self, message_id: int, /) -> None: def delete_message(self, message_id: int, /, *, thread: Snowflake = MISSING) -> None:
"""Deletes a message owned by this webhook. """Deletes a message owned by this webhook.
This is a lower level interface to :meth:`WebhookMessage.delete` in case This is a lower level interface to :meth:`WebhookMessage.delete` in case
@ -1111,6 +1158,10 @@ class SyncWebhook(BaseWebhook):
------------ ------------
message_id: :class:`int` message_id: :class:`int`
The message ID to delete. The message ID to delete.
hread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises Raises
------- -------
@ -1124,10 +1175,15 @@ class SyncWebhook(BaseWebhook):
if self.token is None: if self.token is None:
raise ValueError('This webhook does not have a token associated with it') raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
adapter.delete_webhook_message( adapter.delete_webhook_message(
self.id, self.id,
self.token, self.token,
message_id, message_id,
session=self.session, session=self.session,
thread_id=thread_id,
) )

25
discord/widget.py

@ -188,7 +188,7 @@ class WidgetMember(BaseUser):
except KeyError: except KeyError:
activity = None activity = None
else: else:
activity = create_activity(game) activity = create_activity(game, state)
self.activity: Optional[Union[BaseActivity, Spotify]] = activity self.activity: Optional[Union[BaseActivity, Spotify]] = activity
@ -231,7 +231,7 @@ class Widget:
channels: List[:class:`WidgetChannel`] channels: List[:class:`WidgetChannel`]
The accessible voice channels in the guild. The accessible voice channels in the guild.
members: List[:class:`Member`] members: List[:class:`Member`]
The online members in the server. Offline members The online members in the guild. Offline members
do not appear in the widget. do not appear in the widget.
.. note:: .. note::
@ -240,10 +240,15 @@ class Widget:
the users will be "anonymized" with linear IDs and discriminator the users will be "anonymized" with linear IDs and discriminator
information being incorrect. Likewise, the number of members information being incorrect. Likewise, the number of members
retrieved is capped. retrieved is capped.
presence_count: :class:`int`
The approximate number of online members in the guild.
Offline members are not included in this count.
.. versionadded:: 2.0
""" """
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name') __slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name', 'presence_count')
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
self._state = state self._state = state
@ -268,10 +273,12 @@ class Widget:
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel))
self.presence_count: int = data['presence_count']
def __str__(self) -> str: def __str__(self) -> str:
return self.json_url return self.json_url
def __eq__(self, other: Any) -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, Widget): if isinstance(other, Widget):
return self.id == other.id return self.id == other.id
return False return False
@ -290,11 +297,11 @@ class Widget:
return f"https://discord.com/api/guilds/{self.id}/widget.json" return f"https://discord.com/api/guilds/{self.id}/widget.json"
@property @property
def invite_url(self) -> str: def invite_url(self) -> Optional[str]:
"""Optional[:class:`str`]: The invite URL for the guild, if available.""" """Optional[:class:`str`]: The invite URL for the guild, if available."""
return self._invite return self._invite
async def fetch_invite(self, *, with_counts: bool = True) -> Invite: async def fetch_invite(self, *, with_counts: bool = True) -> Optional[Invite]:
"""|coro| """|coro|
Retrieves an :class:`Invite` from the widget's invite URL. Retrieves an :class:`Invite` from the widget's invite URL.
@ -310,9 +317,11 @@ class Widget:
Returns Returns
-------- --------
:class:`Invite` Optional[:class:`Invite`]
The invite from the widget's invite URL. The invite from the widget's invite URL, if available.
""" """
if self._invite:
resolved = resolve_invite(self._invite) resolved = resolve_invite(self._invite)
data = await self._state.http.get_invite(resolved.code, with_counts=with_counts) data = await self._state.http.get_invite(resolved.code, with_counts=with_counts)
return Invite.from_incomplete(state=self._state, data=data) return Invite.from_incomplete(state=self._state, data=data)
return None

16
docs/_static/custom.js

@ -95,3 +95,19 @@ document.addEventListener('keydown', (event) => {
activeModal.close(); activeModal.close();
} }
}); });
function searchBarClick(event, which) {
event.preventDefault();
if (event.button === 1 || event.buttons === 4) {
which.target = "_blank"; // Middle mouse button was clicked. Set our target to a new tab.
}
else if (event.button === 2) {
return // Right button was clicked... Don't do anything here.
}
else {
which.target = "_self"; // Revert to same window.
}
which.submit();
}

18
docs/_static/style.css

@ -61,6 +61,8 @@ Historically however, thanks to:
--search-focus: var(--blue-1); --search-focus: var(--blue-1);
--search-button: var(--grey-1); --search-button: var(--grey-1);
--search-button-hover: var(--grey-1-8); --search-button-hover: var(--grey-1-8);
--search-sidebar-background: var(--grey-1);
--search-sidebar-text: var(--grey-7);
--footer-text: var(--grey-5); --footer-text: var(--grey-5);
--footer-link: var(--grey-6); --footer-link: var(--grey-6);
--hr-border: var(--grey-2); --hr-border: var(--grey-2);
@ -167,6 +169,8 @@ Historically however, thanks to:
--attribute-table-entry-hover-text: var(--blue-1); --attribute-table-entry-hover-text: var(--blue-1);
--attribute-table-badge: var(--grey-4); --attribute-table-badge: var(--grey-4);
--highlighted-text: rgba(250, 166, 26, 0.2); --highlighted-text: rgba(250, 166, 26, 0.2);
--search-sidebar-background: var(--grey-7);
--search-sidebar-text: var(--search-text);
} }
img[src$="snake_dark.svg"] { img[src$="snake_dark.svg"] {
@ -523,6 +527,20 @@ input[type=search]:focus ~ button[type=submit] {
color: var(--search-button-hover); color: var(--search-button-hover);
} }
/* search sidebar */
.search-sidebar > input[type=search],
.search-sidebar > button[type=submit] {
background-color: var(--search-sidebar-background);
color: var(--search-sidebar-text);
}
.sidebar-toggle .search-sidebar > input[type=search],
.sidebar-toggle .search-sidebar > button[type=submit] {
background-color: var(--mobile-nav-background);
color: var(--mobile-nav-text);
}
/* main content area */ /* main content area */
main { main {

12
docs/_templates/layout.html

@ -89,10 +89,10 @@
<option value="{{ pathto(p + '/index')|e }}" {% if pagename is prefixedwith p %}selected{% endif %}>{{ ext }}</option> <option value="{{ pathto(p + '/index')|e }}" {% if pagename is prefixedwith p %}selected{% endif %}>{{ ext }}</option>
{%- endfor %} {%- endfor %}
</select> </select>
<form role="search" class="search" action="{{ pathto('search') }}" method="get"> <form id="search-form" role="search" class="search" action="{{ pathto('search') }}" method="get">
<div class="search-wrapper"> <div class="search-wrapper">
<input type="search" name="q" placeholder="{{ _('Search documentation') }}" /> <input type="search" name="q" placeholder="{{ _('Search documentation') }}" />
<button type="submit"> <button type="submit" onmousedown="searchBarClick(event, document.getElementById('search-form'));">
<span class="material-icons">search</span> <span class="material-icons">search</span>
</button> </button>
</div> </div>
@ -110,6 +110,14 @@
<span class="material-icons">settings</span> <span class="material-icons">settings</span>
</span> </span>
<div id="sidebar"> <div id="sidebar">
<form id="sidebar-form" role="search" class="search" action="{{ pathto('search') }}" method="get">
<div class="search-wrapper search-sidebar">
<input type="search" name="q" placeholder="{{ _('Search documentation') }}" />
<button type="submit" onmousedown="searchBarClick(event, document.getElementById('sidebar-form'));">
<span class="material-icons">search</span>
</button>
</div>
</form>
{%- include "localtoc.html" %} {%- include "localtoc.html" %}
</div> </div>
</aside> </aside>

1004
docs/api.rst

File diff suppressed because it is too large

2
docs/conf.py

@ -145,6 +145,8 @@ pygments_style = 'friendly'
# Nitpicky mode options # Nitpicky mode options
nitpick_ignore_files = [ nitpick_ignore_files = [
"migrating_to_async",
"migrating_to_v1",
"migrating", "migrating",
"whats_new", "whats_new",
] ]

21
docs/crowdin.yml

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
project_id: "362783"
api_token_env: CROWDIN_API_KEY
files:
- source: /_build/locale/**/*.pot
translation: /locale/%two_letters_code%/LC_MESSAGES/%original_path%/%file_name%.po
# You must use `crowdin download --all` for this project
# I discovered after like an hour of debugging the Java CLI that `--all` actually means "use server sources"
# Without this, crowdin tries to determine the mapping itself, and decides that because
# `/locale/ja/LC_MESSAGES/_build/locale/...` doesn't exist, that it won't download anything
# There is no workaround for this. I tried. Trying to adjust the project base path just breaks things further.
# Crowdin does the conflict resolution on its end. The process to update translations is thus:
# - make gettext
# - crowdin upload
# - crowdin download --all
# You must set ${CROWDIN_API_KEY} in the environment.
# I will write an Actions workflow for this at a later date.

7
docs/ext/commands/api.rst

@ -411,6 +411,9 @@ Converters
.. autoclass:: discord.ext.commands.GuildStickerConverter .. autoclass:: discord.ext.commands.GuildStickerConverter
:members: :members:
.. autoclass:: discord.ext.commands.ScheduledEventConverter
:members:
.. autoclass:: discord.ext.commands.clean_content .. autoclass:: discord.ext.commands.clean_content
:members: :members:
@ -539,6 +542,9 @@ Exceptions
.. autoexception:: discord.ext.commands.GuildStickerNotFound .. autoexception:: discord.ext.commands.GuildStickerNotFound
:members: :members:
.. autoexception:: discord.ext.commands.ScheduledEventNotFound
:members:
.. autoexception:: discord.ext.commands.BadBoolArgument .. autoexception:: discord.ext.commands.BadBoolArgument
:members: :members:
@ -623,6 +629,7 @@ Exception Hierarchy
- :exc:`~.commands.BadInviteArgument` - :exc:`~.commands.BadInviteArgument`
- :exc:`~.commands.EmojiNotFound` - :exc:`~.commands.EmojiNotFound`
- :exc:`~.commands.GuildStickerNotFound` - :exc:`~.commands.GuildStickerNotFound`
- :exc:`~.commands.ScheduledEventNotFound`
- :exc:`~.commands.PartialEmojiConversionFailure` - :exc:`~.commands.PartialEmojiConversionFailure`
- :exc:`~.commands.BadBoolArgument` - :exc:`~.commands.BadBoolArgument`
- :exc:`~.commands.ThreadNotFound` - :exc:`~.commands.ThreadNotFound`

5
docs/ext/commands/cogs.rst

@ -58,7 +58,7 @@ Once you have defined your cogs, you need to tell the bot to register the cogs t
.. code-block:: python3 .. code-block:: python3
bot.add_cog(Greetings(bot)) await bot.add_cog(Greetings(bot))
This binds the cog to the bot, adding all commands and listeners to the bot automatically. This binds the cog to the bot, adding all commands and listeners to the bot automatically.
@ -66,7 +66,7 @@ Note that we reference the cog by name, which we can override through :ref:`ext_
.. code-block:: python3 .. code-block:: python3
bot.remove_cog('Greetings') await bot.remove_cog('Greetings')
Using Cogs Using Cogs
------------- -------------
@ -112,6 +112,7 @@ As cogs get more complicated and have more commands, there comes a point where w
They are as follows: They are as follows:
- :meth:`.Cog.cog_load`
- :meth:`.Cog.cog_unload` - :meth:`.Cog.cog_unload`
- :meth:`.Cog.cog_check` - :meth:`.Cog.cog_check`
- :meth:`.Cog.cog_command_error` - :meth:`.Cog.cog_command_error`

19
docs/ext/commands/commands.rst

@ -11,6 +11,13 @@ how you can arbitrarily nest groups and commands to have a rich sub-command syst
Commands are defined by attaching it to a regular Python function. The command is then invoked by the user using a similar Commands are defined by attaching it to a regular Python function. The command is then invoked by the user using a similar
signature to the Python function. signature to the Python function.
.. warning::
You must have access to the :attr:`~discord.Intents.message_content` intent for the commands extension
to function. This must be set both in the developer portal and within your code.
Failure to do this will result in your bot not responding to any of your commands.
For example, in the given command definition: For example, in the given command definition:
.. code-block:: python3 .. code-block:: python3
@ -171,9 +178,9 @@ As seen earlier, every command must take at least a single parameter, called the
This parameter gives you access to something called the "invocation context". Essentially all the information you need to This parameter gives you access to something called the "invocation context". Essentially all the information you need to
know how the command was executed. It contains a lot of useful information: know how the command was executed. It contains a lot of useful information:
- :attr:`.Context.guild` to fetch the :class:`Guild` of the command, if any. - :attr:`.Context.guild` returns the :class:`Guild` of the command, if any.
- :attr:`.Context.message` to fetch the :class:`Message` of the command. - :attr:`.Context.message` returns the :class:`Message` of the command.
- :attr:`.Context.author` to fetch the :class:`Member` or :class:`User` that called the command. - :attr:`.Context.author` returns the :class:`Member` or :class:`User` that called the command.
- :meth:`.Context.send` to send a message to the channel the command was used in. - :meth:`.Context.send` to send a message to the channel the command was used in.
The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you
@ -393,6 +400,8 @@ A lot of discord models work out of the gate as a parameter:
- :class:`Emoji` - :class:`Emoji`
- :class:`PartialEmoji` - :class:`PartialEmoji`
- :class:`Thread` (since v2.0) - :class:`Thread` (since v2.0)
- :class:`GuildSticker` (since v2.0)
- :class:`ScheduledEvent` (since v2.0)
Having any of these set as the converter will intelligently convert the argument to the appropriate target type you Having any of these set as the converter will intelligently convert the argument to the appropriate target type you
specify. specify.
@ -441,6 +450,10 @@ converter is given below:
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+
| :class:`Thread` | :class:`~ext.commands.ThreadConverter` | | :class:`Thread` | :class:`~ext.commands.ThreadConverter` |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+
| :class:`GuildSticker` | :class:`~ext.commands.GuildStickerConverter` |
+--------------------------+-------------------------------------------------+
| :class:`ScheduledEvent` | :class:`~ext.commands.ScheduledEventConverter` |
+--------------------------+-------------------------------------------------+
By providing the converter it allows us to use them as building blocks for another converter: By providing the converter it allows us to use them as building blocks for another converter:

10
docs/ext/commands/extensions.rst

@ -24,10 +24,10 @@ An example extension looks like this:
async def hello(ctx): async def hello(ctx):
await ctx.send(f'Hello {ctx.author.display_name}.') await ctx.send(f'Hello {ctx.author.display_name}.')
def setup(bot): async def setup(bot):
bot.add_command(hello) bot.add_command(hello)
In this example we define a simple command, and when the extension is loaded this command is added to the bot. Now the final step to this is loading the extension, which we do by calling :meth:`.Bot.load_extension`. To load this extension we call ``bot.load_extension('hello')``. In this example we define a simple command, and when the extension is loaded this command is added to the bot. Now the final step to this is loading the extension, which we do by calling :meth:`.Bot.load_extension`. To load this extension we call ``await bot.load_extension('hello')``.
.. admonition:: Cogs .. admonition:: Cogs
:class: helpful :class: helpful
@ -45,7 +45,7 @@ When you make a change to the extension and want to reload the references, the l
.. code-block:: python3 .. code-block:: python3
>>> bot.reload_extension('hello') >>> await bot.reload_extension('hello')
Once the extension reloads, any changes that we did will be applied. This is useful if we want to add or remove functionality without restarting our bot. If an error occurred during the reloading process, the bot will pretend as if the reload never happened. Once the extension reloads, any changes that we did will be applied. This is useful if we want to add or remove functionality without restarting our bot. If an error occurred during the reloading process, the bot will pretend as if the reload never happened.
@ -57,8 +57,8 @@ Although rare, sometimes an extension needs to clean-up or know when it's being
.. code-block:: python3 .. code-block:: python3
:caption: basic_ext.py :caption: basic_ext.py
def setup(bot): async def setup(bot):
print('I am being loaded!') print('I am being loaded!')
def teardown(bot): async def teardown(bot):
print('I am being unloaded!') print('I am being unloaded!')

4
docs/faq.rst

@ -326,10 +326,6 @@ Quick example: ::
embed.set_image(url="attachment://image.png") embed.set_image(url="attachment://image.png")
await channel.send(file=file, embed=embed) await channel.send(file=file, embed=embed)
.. note ::
Due to a Discord limitation, filenames may not include underscores.
Is there an event for audit log entries being created? Is there an event for audit log entries being created?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1397
docs/migrating.rst

File diff suppressed because it is too large

1174
docs/migrating_to_v1.rst

File diff suppressed because it is too large

7
docs/quickstart.rst

@ -19,9 +19,14 @@ It looks something like this:
.. code-block:: python3 .. code-block:: python3
# This example requires the 'message_content' intent.
import discord import discord
client = discord.Client() intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
@client.event @client.event
async def on_ready(): async def on_ready():

1
examples/background_task.py

@ -9,6 +9,7 @@ class MyClient(discord.Client):
# an attribute we can access from our task # an attribute we can access from our task
self.counter = 0 self.counter = 0
async def setup_hook(self) -> None:
# start the task to run in the background # start the task to run in the background
self.my_background_task.start() self.my_background_task.start()

1
examples/background_task_asyncio.py

@ -5,6 +5,7 @@ class MyClient(discord.Client):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def setup_hook(self) -> None:
# create the background task and run it in the background # create the background task and run it in the background
self.bg_task = self.loop.create_task(self.my_background_task()) self.bg_task = self.loop.create_task(self.my_background_task())

8
examples/basic_voice.py

@ -131,5 +131,9 @@ async def on_ready():
print(f'Logged in as {bot.user} (ID: {bot.user.id})') print(f'Logged in as {bot.user} (ID: {bot.user.id})')
print('------') print('------')
bot.add_cog(Music(bot)) async def main():
bot.run('token') async with bot:
await bot.add_cog(Music(bot))
await bot.start('token')
asyncio.run(main())

69
examples/modal.py

@ -0,0 +1,69 @@
import discord
from discord import app_commands
import traceback
# Just default intents and a `discord.Client` instance
# We don't need a `commands.Bot` instance because we are not
# creating text-based commands.
intents = discord.Intents.default()
client = discord.Client(intents=intents)
# We need an `discord.app_commands.CommandTree` instance
# to register application commands (slash commands in this case)
tree = app_commands.CommandTree(client)
# The guild in which this slash command will be registered.
# As global commands can take up to an hour to propagate, it is ideal
# to test it in a guild.
TEST_GUILD = discord.Object(ID)
@client.event
async def on_ready():
print(f'Logged in as {client.user} (ID: {client.user.id})')
print('------')
# Sync the application command with Discord.
await tree.sync(guild=TEST_GUILD)
class Feedback(discord.ui.Modal, title='Feedback'):
# Our modal classes MUST subclass `discord.ui.Modal`,
# but the title can be whatever you want.
# This will be a short input, where the user can enter their name
# It will also have a placeholder, as denoted by the `placeholder` kwarg.
# By default, it is required and is a short-style input which is exactly
# what we want.
name = discord.ui.TextInput(
label='Name',
placeholder='Your name here...',
)
# This is a longer, paragraph style input, where user can submit feedback
# Unlike the name, it is not required. If filled out, however, it will
# only accept a maximum of 300 characters, as denoted by the
# `max_length=300` kwarg.
feedback = discord.ui.TextInput(
label='What do you think of this new feature?',
style=discord.TextStyle.long,
placeholder='Type your feedback here...',
required=False,
max_length=300,
)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f'Thanks for your feedback, {self.name.value}!', ephemeral=True)
async def on_error(self, error: Exception, interaction: discord.Interaction) -> None:
await interaction.response.send_message('Oops! Something went wrong.', ephemeral=True)
# Make sure we know what the error actually is
traceback.print_tb(error.__traceback__)
@tree.command(guild=TEST_GUILD, description="Submit feedback")
async def feedback(interaction: discord.Interaction):
# Send the modal with an instance of our `Feedback` class
await interaction.response.send_modal(Feedback())
client.run('token')

2
pyproject.toml

@ -28,6 +28,7 @@ line_length = 125
[tool.pyright] [tool.pyright]
include = [ include = [
"discord", "discord",
"discord/app_commands",
"discord/types", "discord/types",
"discord/ext", "discord/ext",
"discord/ext/commands", "discord/ext/commands",
@ -39,6 +40,7 @@ exclude = [
"dist", "dist",
"docs", "docs",
] ]
reportUnnecessaryTypeIgnoreComment = "warning"
pythonVersion = "3.8" pythonVersion = "3.8"
typeCheckingMode = "basic" typeCheckingMode = "basic"

2
requirements.txt

@ -1 +1 @@
aiohttp>=3.6.0,<3.9.0 aiohttp>=3.6.0,<4

2
setup.py

@ -34,7 +34,7 @@ with open('README.rst') as f:
readme = f.read() readme = f.read()
extras_require = { extras_require = {
'voice': ['PyNaCl>=1.3.0,<1.5'], 'voice': ['PyNaCl>=1.3.0,<1.6'],
'docs': [ 'docs': [
'sphinx==4.4.0', 'sphinx==4.4.0',
'sphinxcontrib_trio==1.1.2', 'sphinxcontrib_trio==1.1.2',

103
tests/test_ext_tasks.py

@ -10,6 +10,7 @@ import asyncio
import datetime import datetime
import pytest import pytest
import sys
from discord import utils from discord import utils
from discord.ext import tasks from discord.ext import tasks
@ -75,3 +76,105 @@ async def test_explicit_initial_runs_tomorrow_multi():
assert not has_run assert not has_run
finally: finally:
loop.cancel() loop.cancel()
def test_task_regression_issue7659():
jst = datetime.timezone(datetime.timedelta(hours=9))
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
@tasks.loop(time=times)
async def loop():
pass
before_midnight = datetime.datetime(2022, 3, 12, 23, 50, 59, tzinfo=jst)
after_midnight = before_midnight + datetime.timedelta(minutes=9, seconds=2)
expected_before_midnight = datetime.datetime(2022, 3, 13, 0, 0, 0, tzinfo=jst)
expected_after_midnight = datetime.datetime(2022, 3, 13, 3, 0, 0, tzinfo=jst)
assert loop._get_next_sleep_time(before_midnight) == expected_before_midnight
assert loop._get_next_sleep_time(after_midnight) == expected_after_midnight
today = datetime.date.today()
minute_before = [datetime.datetime.combine(today, time, tzinfo=jst) - datetime.timedelta(minutes=1) for time in times]
for before, expected_time in zip(minute_before, times):
expected = datetime.datetime.combine(today, expected_time, tzinfo=jst)
actual = loop._get_next_sleep_time(before)
assert actual == expected
def test_task_regression_issue7676():
jst = datetime.timezone(datetime.timedelta(hours=9))
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
@tasks.loop(time=times)
async def loop():
pass
# Create pseudo UTC times
now = utils.utcnow()
today = now.date()
times_before_in_utc = [
datetime.datetime.combine(today, time, tzinfo=jst).astimezone(datetime.timezone.utc) - datetime.timedelta(minutes=1)
for time in times
]
for before, expected_time in zip(times_before_in_utc, times):
actual = loop._get_next_sleep_time(before)
actual_time = actual.timetz()
assert actual_time == expected_time
@pytest.mark.skipif(sys.version_info < (3, 9), reason="zoneinfo requires 3.9")
def test_task_is_imaginary():
import zoneinfo
tz = zoneinfo.ZoneInfo('America/New_York')
# 2:30 AM was skipped
dt = datetime.datetime(2022, 3, 13, 2, 30, tzinfo=tz)
assert tasks.is_imaginary(dt)
now = utils.utcnow()
# UTC time is never imaginary or ambiguous
assert not tasks.is_imaginary(now)
@pytest.mark.skipif(sys.version_info < (3, 9), reason="zoneinfo requires 3.9")
def test_task_is_ambiguous():
import zoneinfo
tz = zoneinfo.ZoneInfo('America/New_York')
# 1:30 AM happened twice
dt = datetime.datetime(2022, 11, 6, 1, 30, tzinfo=tz)
assert tasks.is_ambiguous(dt)
now = utils.utcnow()
# UTC time is never imaginary or ambiguous
assert not tasks.is_imaginary(now)
@pytest.mark.skipif(sys.version_info < (3, 9), reason="zoneinfo requires 3.9")
@pytest.mark.parametrize(
('dt', 'key', 'expected'),
[
(datetime.datetime(2022, 11, 6, 1, 30), 'America/New_York', datetime.datetime(2022, 11, 6, 1, 30, fold=1)),
(datetime.datetime(2022, 3, 13, 2, 30), 'America/New_York', datetime.datetime(2022, 3, 13, 3, 30)),
(datetime.datetime(2022, 4, 8, 2, 30), 'America/New_York', datetime.datetime(2022, 4, 8, 2, 30)),
(datetime.datetime(2023, 1, 7, 12, 30), 'UTC', datetime.datetime(2023, 1, 7, 12, 30)),
],
)
def test_task_date_resolve(dt, key, expected):
import zoneinfo
tz = zoneinfo.ZoneInfo(key)
actual = tasks.resolve_datetime(dt.replace(tzinfo=tz))
expected = expected.replace(tzinfo=tz)
assert actual == expected

Loading…
Cancel
Save