Browse Source

Rebase to upstream

pull/10109/head
dolfies 3 years ago
parent
commit
c932b5d06b
  1. 44
      .github/workflows/crowdin_download.yml
  2. 44
      .github/workflows/crowdin_upload.yml
  3. 3
      discord/__init__.py
  4. 38
      discord/__main__.py
  5. 69
      discord/abc.py
  6. 59
      discord/activity.py
  7. 69
      discord/asset.py
  8. 36
      discord/audit_logs.py
  9. 60
      discord/channel.py
  10. 265
      discord/client.py
  11. 13
      discord/colour.py
  12. 10
      discord/components.py
  13. 10
      discord/context_managers.py
  14. 190
      discord/embeds.py
  15. 10
      discord/emoji.py
  16. 45
      discord/enums.py
  17. 24
      discord/ext/commands/_types.py
  18. 334
      discord/ext/commands/bot.py
  19. 134
      discord/ext/commands/cog.py
  20. 16
      discord/ext/commands/context.py
  21. 153
      discord/ext/commands/converter.py
  22. 4
      discord/ext/commands/cooldowns.py
  23. 195
      discord/ext/commands/core.py
  24. 33
      discord/ext/commands/errors.py
  25. 24
      discord/ext/commands/flags.py
  26. 585
      discord/ext/commands/help.py
  27. 37
      discord/ext/commands/view.py
  28. 186
      discord/ext/tasks/__init__.py
  29. 2
      discord/file.py
  30. 21
      discord/flags.py
  31. 84
      discord/gateway.py
  32. 141
      discord/guild.py
  33. 75
      discord/http.py
  34. 21
      discord/integrations.py
  35. 59
      discord/invite.py
  36. 58
      discord/member.py
  37. 10
      discord/mentions.py
  38. 2279
      discord/message.py
  39. 5
      discord/opus.py
  40. 15
      discord/partial_emoji.py
  41. 19
      discord/permissions.py
  42. 31
      discord/player.py
  43. 10
      discord/reaction.py
  44. 8
      discord/role.py
  45. 78
      discord/scheduled_event.py
  46. 22
      discord/stage_instance.py
  47. 74
      discord/state.py
  48. 4
      discord/template.py
  49. 70
      discord/threads.py
  50. 1
      discord/types/activity.py
  51. 1
      discord/types/channel.py
  52. 242
      discord/types/interactions.py
  53. 4
      discord/types/scheduled_event.py
  54. 14
      discord/types/widget.py
  55. 10
      discord/user.py
  56. 53
      discord/utils.py
  57. 25
      discord/voice_client.py
  58. 337
      discord/webhook/async_.py
  59. 148
      discord/webhook/sync.py
  60. 31
      discord/widget.py
  61. 16
      docs/_static/custom.js
  62. 18
      docs/_static/style.css
  63. 12
      docs/_templates/layout.html
  64. 1008
      docs/api.rst
  65. 2
      docs/conf.py
  66. 21
      docs/crowdin.yml
  67. 7
      docs/ext/commands/api.rst
  68. 5
      docs/ext/commands/cogs.rst
  69. 19
      docs/ext/commands/commands.rst
  70. 10
      docs/ext/commands/extensions.rst
  71. 4
      docs/faq.rst
  72. 1397
      docs/migrating.rst
  73. 1174
      docs/migrating_to_v1.rst
  74. 7
      docs/quickstart.rst
  75. 1
      examples/background_task.py
  76. 1
      examples/background_task_asyncio.py
  77. 8
      examples/basic_voice.py
  78. 69
      examples/modal.py
  79. 2
      pyproject.toml
  80. 2
      requirements.txt
  81. 2
      setup.py
  82. 103
      tests/test_ext_tasks.py

44
.github/workflows/crowdin_download.yml

@ -0,0 +1,44 @@
name: crowdin download
on:
schedule:
- cron: '0 18 * * 1'
workflow_dispatch:
jobs:
download:
runs-on: ubuntu-latest
environment: Crowdin
name: download
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
ref: master
- name: Install system dependencies
run: |
wget -qO - https://artifacts.crowdin.com/repo/GPG-KEY-crowdin | sudo apt-key add -
echo "deb https://artifacts.crowdin.com/repo/deb/ /" | sudo tee -a /etc/apt/sources.list.d/crowdin.list
sudo apt-get update -qq
sudo apt-get install -y crowdin3
- name: Download translations
shell: bash
run: |
cd docs
crowdin download --all
env:
CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }}
- name: Create pull request
id: cpr_crowdin
uses: peter-evans/create-pull-request@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Crowdin translations download
title: "[Crowdin] Updated translation files"
body: |
Created by the [Crowdin download workflow](.github/workflows/crowdin_download.yml).
branch: "auto/crowdin"
author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

44
.github/workflows/crowdin_upload.yml

@ -0,0 +1,44 @@
name: crowdin upload
on:
workflow_dispatch:
jobs:
upload:
runs-on: ubuntu-latest
environment: Crowdin
name: upload
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Set up CPython 3.x
uses: actions/setup-python@v2
with:
python-version: 3.x
- name: Install system dependencies
run: |
wget -qO - https://artifacts.crowdin.com/repo/GPG-KEY-crowdin | sudo apt-key add -
echo "deb https://artifacts.crowdin.com/repo/deb/ /" | sudo tee -a /etc/apt/sources.list.d/crowdin.list
sudo apt-get update -qq
sudo apt-get install -y crowdin3
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install -e .[docs,speed,voice]
- name: Build gettext
run: |
cd docs
make gettext
- name: Upload sources
shell: bash
run: |
cd docs
crowdin upload
env:
CROWDIN_API_KEY: ${{ secrets.CROWDIN_API_KEY }}

3
discord/__init__.py

@ -77,3 +77,6 @@ class _VersionInfo(NamedTuple):
version_info: _VersionInfo = _VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=2)
logging.getLogger(__name__).addHandler(logging.NullHandler())
del logging, NamedTuple, Literal, _VersionInfo

38
discord/__main__.py

@ -22,6 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional, Tuple, Dict
import argparse
import sys
from pathlib import Path
@ -32,7 +36,7 @@ import aiohttp
import platform
def show_version():
def show_version() -> None:
entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
@ -49,7 +53,7 @@ def show_version():
print('\n'.join(entries))
def core(parser, args):
def core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.version:
show_version()
@ -63,9 +67,11 @@ import config
class Bot(commands.Bot):
def __init__(self, **kwargs):
super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), **kwargs)
async def setup_hook(self):
for cog in config.cogs:
try:
self.load_extension(cog)
await self.load_extension(cog)
except Exception as exc:
print(f'Could not load extension {{cog}} due to {{exc.__class__.__name__}}: {{exc}}')
@ -119,12 +125,16 @@ class {name}(commands.Cog{attrs}):
def __init__(self, bot):
self.bot = bot
{extra}
def setup(bot):
bot.add_cog({name}(bot))
async def setup(bot):
await bot.add_cog({name}(bot))
'''
_cog_extras = '''
def cog_unload(self):
async def cog_load(self):
# loading logic goes here
pass
async def cog_unload(self):
# clean up logic goes here
pass
@ -158,7 +168,7 @@ _cog_extras = '''
# certain file names and directory names are forbidden
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
# although some of this doesn't apply to Linux, we might as well be consistent
_base_table = {
_base_table: Dict[str, Optional[str]] = {
'<': '-',
'>': '-',
':': '-',
@ -176,7 +186,7 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False):
def to_path(parser: argparse.ArgumentParser, name: str, *, replace_spaces: bool = False) -> Path:
if isinstance(name, Path):
return name
@ -214,7 +224,7 @@ def to_path(parser, name, *, replace_spaces=False):
return Path(name)
def newbot(parser, args):
def newbot(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
# as a note exist_ok for Path is a 3.5+ only feature
@ -255,7 +265,7 @@ def newbot(parser, args):
print('successfully made bot at', new_directory)
def newcog(parser, args):
def newcog(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
cog_dir = to_path(parser, args.directory)
try:
cog_dir.mkdir(exist_ok=True)
@ -289,7 +299,7 @@ def newcog(parser, args):
print('successfully made cog at', directory)
def add_newbot_args(subparser):
def add_newbot_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser.set_defaults(func=newbot)
@ -299,7 +309,7 @@ def add_newbot_args(subparser):
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
def add_newcog_args(subparser):
def add_newcog_args(subparser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser.set_defaults(func=newcog)
@ -311,7 +321,7 @@ def add_newcog_args(subparser):
parser.add_argument('--full', help='add all special methods as well', action='store_true')
def parse_args():
def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]:
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser.set_defaults(func=core)
@ -322,7 +332,7 @@ def parse_args():
return parser, parser.parse_args()
def main():
def main() -> None:
parser, args = parse_args()
args.func(parser, args)

69
discord/abc.py

@ -90,6 +90,9 @@ if TYPE_CHECKING:
GuildChannel as GuildChannelPayload,
OverwriteType,
)
from .types.snowflake import (
SnowflakeList,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
@ -725,7 +728,14 @@ class GuildChannel:
) -> None:
...
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Any = _undefined,
reason: Optional[str] = None,
**permissions: bool,
) -> None:
r"""|coro|
Sets the channel specific permission overwrites for a target in the
@ -769,8 +779,8 @@ class GuildChannel:
await channel.set_permissions(member, overwrite=overwrite)
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
@ -934,7 +944,7 @@ class GuildChannel:
) -> None:
...
async def move(self, **kwargs) -> None:
async def move(self, **kwargs: Any) -> None:
"""|coro|
A rich interface to help move a channel relative to other channels.
@ -952,8 +962,8 @@ class GuildChannel:
.. versionadded:: 1.7
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
Parameters
------------
@ -1210,7 +1220,7 @@ class Messageable:
content: Optional[str] = ...,
*,
tts: bool = ...,
files: List[File] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
@ -1244,7 +1254,7 @@ class Messageable:
content: Optional[str] = ...,
*,
tts: bool = ...,
files: List[File] = ...,
files: Sequence[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
@ -1257,19 +1267,19 @@ class Messageable:
async def send(
self,
content=None,
content: Optional[str] = None,
*,
tts=False,
file=None,
files=None,
stickers=None,
delete_after=None,
nonce=MISSING,
allowed_mentions=None,
reference=None,
mention_author=None,
suppress_embeds=False,
):
tts: bool = False,
file: Optional[File] = None,
files: Optional[Sequence[File]] = None,
stickers: Optional[Sequence[Union[GuildSticker, StickerItem]]] = None,
delete_after: Optional[float] = None,
nonce: Optional[Union[str, int]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
reference: Optional[Union[Message, MessageReference, PartialMessage]] = None,
mention_author: Optional[bool] = None,
suppress_embeds: bool = False,
) -> Message:
"""|coro|
Sends a message to the destination with the content given.
@ -1283,8 +1293,8 @@ class Messageable:
**Specifying both parameters will lead to an exception**.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
Parameters
------------
@ -1362,17 +1372,17 @@ class Messageable:
nonce = str(utils.time_snowflake(datetime.utcnow()))
if stickers is not None:
stickers = [sticker.id for sticker in stickers]
sticker_ids: SnowflakeList = [sticker.id for sticker in stickers]
else:
stickers = MISSING
sticker_ids = MISSING
if reference is not None:
try:
reference = reference.to_message_reference_dict()
reference_dict = reference.to_message_reference_dict()
except AttributeError:
raise TypeError('reference parameter must be Message, MessageReference, or PartialMessage') from None
else:
reference = MISSING
reference_dict = MISSING
if suppress_embeds:
from .message import MessageFlags # circular import
@ -1388,10 +1398,10 @@ class Messageable:
files=files if files is not None else MISSING,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
message_reference=reference_dict,
previous_allowed_mentions=previous_allowed_mention,
mention_author=mention_author,
stickers=stickers,
stickers=sticker_ids,
flags=flags,
) as params:
data = await state.http.send_message(channel.id, params=params)
@ -1823,7 +1833,8 @@ class Connectable(Protocol):
if cls is MISSING:
cls = VoiceClient
voice = cls(state.client, channel)
# The type checker doesn't understand that VoiceClient *is* T here.
voice: T = cls(state.client, channel) # type: ignore
if not isinstance(voice, VoiceProtocol):
raise TypeError('Type must meet VoiceProtocol abstract base class')

59
discord/activity.py

@ -99,6 +99,8 @@ if TYPE_CHECKING:
ActivityButton,
)
from .state import ConnectionState
class BaseActivity:
"""The base activity that all user-settable activities inherit from.
@ -121,7 +123,7 @@ class BaseActivity:
__slots__ = ('_created_at',)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
self._created_at: Optional[float] = kwargs.pop('created_at', None)
@property
@ -216,7 +218,7 @@ class Activity(BaseActivity):
'buttons',
)
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None)
@ -377,7 +379,7 @@ class Game(BaseActivity):
__slots__ = ('name', '_end', '_start')
def __init__(self, name: str, **extra):
def __init__(self, name: str, **extra: Any) -> None:
super().__init__(**extra)
self.name: str = name
@ -434,10 +436,10 @@ class Game(BaseActivity):
}
# fmt: on
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Game) and other.name == self.name
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -491,7 +493,7 @@ class Streaming(BaseActivity):
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
def __init__(self, *, name: Optional[str], url: str, **extra: Any) -> None:
super().__init__(**extra)
self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name)
@ -515,7 +517,7 @@ class Streaming(BaseActivity):
return f'<Streaming name={self.name!r}>'
@property
def twitch_name(self):
def twitch_name(self) -> Optional[str]:
"""Optional[:class:`str`]: If provided, the twitch name of the user streaming.
This corresponds to the ``large_image`` key of the :attr:`Streaming.assets`
@ -542,10 +544,10 @@ class Streaming(BaseActivity):
ret['details'] = self.details
return ret
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -577,14 +579,14 @@ class Spotify:
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
def __init__(self, **data):
def __init__(self, **data: Any) -> None:
self._state: str = data.pop('state', '')
self._details: str = data.pop('details', '')
self._timestamps: Dict[str, int] = data.pop('timestamps', {})
self._timestamps: ActivityTimestamps = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {})
self._sync_id: str = data.pop('sync_id')
self._session_id: str = data.pop('session_id')
self._sync_id: str = data.pop('sync_id', '')
self._session_id: Optional[str] = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None)
@property
@ -636,7 +638,7 @@ class Spotify:
""":class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return (
isinstance(other, Spotify)
and other._session_id == self._session_id
@ -644,7 +646,7 @@ class Spotify:
and other.start == self.start
)
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -705,12 +707,14 @@ class Spotify:
@property
def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc)
# the start key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property
def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc)
# the end key will be present here
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) # type: ignore
@property
def duration(self) -> datetime.timedelta:
@ -820,10 +824,10 @@ class CustomActivity(BaseActivity):
o['expires_at'] = expiry.isoformat()
return o
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -845,16 +849,16 @@ ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload
def create_activity(data: ActivityPayload) -> ActivityTypes:
def create_activity(data: ActivityPayload, state: ConnectionState) -> ActivityTypes:
...
@overload
def create_activity(data: None) -> None:
def create_activity(data: None, state: ConnectionState) -> None:
...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
def create_activity(data: Optional[ActivityPayload], state: ConnectionState) -> Optional[ActivityTypes]:
if not data:
return None
@ -867,10 +871,10 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
try:
name = data.pop('name')
except KeyError:
return Activity(**data)
ret = Activity(**data)
else:
# We removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore
ret = CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming:
if 'url' in data:
# The url won't be None here
@ -878,7 +882,12 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
return Spotify(**data)
return Activity(**data)
else:
ret = Activity(**data)
if isinstance(ret.emoji, PartialEmoji):
ret.emoji._state = state
return ret
def create_settings_activity(*, data, state):

69
discord/asset.py

@ -39,6 +39,13 @@ __all__ = (
# fmt: on
if TYPE_CHECKING:
from typing_extensions import Self
from .state import ConnectionState
from .webhook.async_ import _WebhookState
_State = Union[ConnectionState, _WebhookState]
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
@ -77,7 +84,7 @@ class AssetMixin:
return await self._state.http.get_from_cdn(self.url)
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int:
async def save(self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], *, seek_begin: bool = True) -> int:
"""|coro|
Saves this asset into a file-like object.
@ -153,14 +160,14 @@ class Asset(AssetMixin):
BASE = 'https://cdn.discordapp.com'
def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
self._url = url
self._animated = animated
self._key = key
def __init__(self, state: _State, *, url: str, key: str, animated: bool = False) -> None:
self._state: _State = state
self._url: str = url
self._animated: bool = animated
self._key: str = key
@classmethod
def _from_default_avatar(cls, state, index: int) -> Asset:
def _from_default_avatar(cls, state: _State, index: int) -> Self:
return cls(
state,
url=f'{cls.BASE}/embed/avatars/{index}.png',
@ -169,7 +176,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
def _from_avatar(cls, state: _State, user_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -180,7 +187,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
def _from_guild_avatar(cls, state: _State, guild_id: int, member_id: int, avatar: str) -> Self:
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -191,7 +198,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
def _from_icon(cls, state: _State, object_id: int, icon_hash: str, path: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
@ -200,7 +207,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
def _from_cover_image(cls, state: _State, object_id: int, cover_image_hash: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
@ -209,7 +216,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_scheduled_event_cover_image(cls, state, scheduled_event_id: int, cover_image_hash: str) -> Asset:
def _from_scheduled_event_cover_image(cls, state: _State, scheduled_event_id: int, cover_image_hash: str) -> Self:
return cls(
state,
url=f'{cls.BASE}/guild-events/{scheduled_event_id}/{cover_image_hash}.png?size=1024',
@ -218,7 +225,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
def _from_guild_image(cls, state: _State, guild_id: int, image: str, path: str) -> Self:
animated = image.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -229,7 +236,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
def _from_guild_icon(cls, state: _State, guild_id: int, icon_hash: str) -> Self:
animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -240,7 +247,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset:
def _from_sticker_banner(cls, state: _State, banner: int) -> Self:
return cls(
state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
@ -249,7 +256,7 @@ class Asset(AssetMixin):
)
@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
def _from_user_banner(cls, state: _State, user_id: int, banner_hash: str) -> Self:
animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
@ -274,14 +281,14 @@ class Asset(AssetMixin):
def __len__(self) -> int:
return len(self._url)
def __repr__(self):
def __repr__(self) -> str:
shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>'
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, Asset) and self._url == other._url
def __hash__(self):
def __hash__(self) -> int:
return hash(self._url)
@property
@ -304,12 +311,12 @@ class Asset(AssetMixin):
size: int = MISSING,
format: ValidAssetFormatTypes = MISSING,
static_format: ValidStaticFormatTypes = MISSING,
) -> Asset:
) -> Self:
"""Returns a new asset with the passed components replaced.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
-----------
@ -359,12 +366,12 @@ class Asset(AssetMixin):
url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Asset:
def with_size(self, size: int, /) -> Self:
"""Returns a new asset with the specified size.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
@ -387,12 +394,12 @@ class Asset(AssetMixin):
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset:
def with_format(self, format: ValidAssetFormatTypes, /) -> Self:
"""Returns a new asset with the specified format.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
@ -422,15 +429,15 @@ class Asset(AssetMixin):
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Self:
"""Returns a new asset with the specified static format.
This only changes the format if the underlying asset is
not animated. Otherwise, the asset is not changed.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------

36
discord/audit_logs.py

@ -49,12 +49,13 @@ if TYPE_CHECKING:
from .guild import Guild
from .member import Member
from .role import Role
from .scheduled_event import ScheduledEvent
from .state import ConnectionState
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload,
)
from .types.channel import (
PartialChannel as PartialChannelPayload,
PermissionOverwrite as PermissionOverwritePayload,
)
from .types.invite import Invite as InvitePayload
@ -241,8 +242,8 @@ class AuditLogChanges:
# fmt: on
def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]):
self.before = AuditLogDiff()
self.after = AuditLogDiff()
self.before: AuditLogDiff = AuditLogDiff()
self.after: AuditLogDiff = AuditLogDiff()
for elem in data:
attr = elem['key']
@ -389,16 +390,17 @@ class AuditLogEntry(Hashable):
"""
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild):
self._state = guild._state
self.guild = guild
self._users = users
self._state: ConnectionState = guild._state
self.guild: Guild = guild
self._users: Dict[int, User] = users
self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id'])
self.action: enums.AuditLogAction = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id: int = int(data['id'])
self.reason = data.get('reason')
# This key is technically not usually present
self.reason: Optional[str] = data.get('reason')
extra = data.get('options')
# fmt: off
@ -462,10 +464,13 @@ class AuditLogEntry(Hashable):
self._changes = data.get('changes', [])
user_id = utils._get_as_snowflake(data, 'user_id')
self.user = user_id and self._get_member(user_id)
self.user: Optional[Union[User, Member]] = self._get_member(user_id)
self._target_id = utils._get_as_snowflake(data, 'target_id')
def _get_member(self, user_id: int) -> Union[Member, User, None]:
def _get_member(self, user_id: Optional[int]) -> Union[Member, User, None]:
if user_id is None:
return None
return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str:
@ -478,12 +483,14 @@ class AuditLogEntry(Hashable):
@utils.cached_property
def target(self) -> TargetType:
if self._target_id is None or self.action.target_type is None:
if self.action.target_type is None:
return None
try:
converter = getattr(self, '_convert_target_' + self.action.target_type)
except AttributeError:
if self._target_id is None:
return None
return Object(id=self._target_id)
else:
return converter(self._target_id)
@ -522,7 +529,7 @@ class AuditLogEntry(Hashable):
def _convert_target_role(self, target_id: int) -> Union[Role, Object]:
return self.guild.get_role(target_id) or Object(id=target_id)
def _convert_target_invite(self, target_id: int) -> Invite:
def _convert_target_invite(self, target_id: None) -> Invite:
# Invites have target_id set to null
# So figure out which change has the full invite data
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
@ -557,3 +564,6 @@ class AuditLogEntry(Hashable):
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id)
def _convert_target_guild_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, Object]:
return self.guild.get_scheduled_event(target_id) or Object(id=target_id)

60
discord/channel.py

@ -46,7 +46,6 @@ from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, PrivacyLevel, try_enum, VideoQualityMode
from .calls import PrivateCall, GroupCall
from .mixins import Hashable
from .object import Object
from . import utils
from .utils import MISSING
from .asset import Asset
@ -194,7 +193,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id')
self._fill_overwrites(data)
async def _get_channel(self):
async def _get_channel(self) -> Self:
return self
@property
@ -279,7 +278,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[TextChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[TextChannel]:
"""|coro|
Edits the channel.
@ -297,8 +296,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Edits are no longer in-place, the newly edited channel is returned instead.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
Parameters
----------
@ -574,8 +573,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
.. versionadded:: 1.3
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -696,7 +695,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
reason: Optional[:class:`str`]
The reason for creating a new thread. Shows up on the audit log.
invitable: :class:`bool`
Whether non-modertators can add users to the thread. Only applicable to private threads.
Whether non-moderators can add users to the thread. Only applicable to private threads.
Defaults to ``True``.
slowmode_delay: Optional[:class:`int`]
Specifies the slowmode rate limit for user in this channel, in seconds.
@ -863,7 +862,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha
return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild
self.guild: Guild = guild
self.name: str = data['name']
self.rtc_region: Optional[str] = data.get('rtc_region')
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1))
@ -1031,7 +1030,7 @@ class VoiceChannel(VocalGuildChannel):
async def edit(self) -> Optional[VoiceChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[VoiceChannel]:
"""|coro|
Edits the channel.
@ -1049,8 +1048,8 @@ class VoiceChannel(VocalGuildChannel):
The ``region`` parameter now accepts :class:`str` instead of an enum.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
----------
@ -1175,7 +1174,7 @@ class StageChannel(VocalGuildChannel):
def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data)
self.topic = data.get('topic')
self.topic: Optional[str] = data.get('topic')
@property
def requesting_to_speak(self) -> List[Member]:
@ -1316,7 +1315,7 @@ class StageChannel(VocalGuildChannel):
async def edit(self) -> Optional[StageChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StageChannel]:
"""|coro|
Edits the channel.
@ -1334,8 +1333,8 @@ class StageChannel(VocalGuildChannel):
The ``region`` parameter now accepts :class:`str` instead of an enum.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
----------
@ -1477,7 +1476,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[CategoryChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[CategoryChannel]:
"""|coro|
Edits the channel.
@ -1492,8 +1491,8 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
Edits are no longer in-place, the newly edited channel is returned instead.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
Parameters
----------
@ -1533,7 +1532,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
@utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs):
async def move(self, **kwargs: Any) -> None:
kwargs.pop('category', None)
await super().move(**kwargs)
@ -1717,9 +1716,9 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
position: int = ...,
nsfw: bool = ...,
sync_permissions: bool = ...,
category: Optional[CategoryChannel],
reason: Optional[str],
overwrites: Mapping[Union[Role, Member], PermissionOverwrite],
category: Optional[CategoryChannel] = ...,
reason: Optional[str] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
) -> Optional[StoreChannel]:
...
@ -1727,7 +1726,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
async def edit(self) -> Optional[StoreChannel]:
...
async def edit(self, *, reason=None, **options):
async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]:
"""|coro|
Edits the channel.
@ -1739,8 +1738,8 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
Edits are no longer in-place, the newly edited channel is returned instead.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
Parameters
----------
@ -1844,7 +1843,7 @@ class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
def _add_call(self, **kwargs) -> PrivateCall:
return PrivateCall(**kwargs)
async def _get_channel(self):
async def _get_channel(self) -> Self:
await self._state.access_private_channel(self.id)
return self
@ -2066,7 +2065,7 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
def _get_voice_state_pair(self) -> Tuple[int, int]:
return self.me.id, self.id
async def _get_channel(self):
async def _get_channel(self) -> Self:
await self._state.access_private_channel(self.id)
return self
@ -2331,7 +2330,7 @@ class GroupChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable):
class PartialMessageable(discord.abc.Messageable, Hashable):
"""Represents a partial messageable to aid with working messageable channels when
only a channel ID are present.
only a channel ID is present.
The only way to construct this class is through :meth:`Client.get_partial_messageable`.
@ -2367,6 +2366,9 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
self.type: Optional[ChannelType] = type
self.last_message_id: Optional[int] = None
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} type={self.type!r}>'
async def _get_channel(self) -> PartialMessageable:
return self

265
discord/client.py

@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio
import datetime
import logging
import signal
import sys
import traceback
from typing import (
@ -42,6 +41,7 @@ from typing import (
Sequence,
TYPE_CHECKING,
Tuple,
Type,
TypeVar,
Union,
)
@ -80,6 +80,8 @@ from .team import Team
from .member import _ClientStatus
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from .types.guild import Guild as GuildPayload
from .guild import GuildChannel
from .abc import PrivateChannel, GuildChannel, Snowflake, SnowflakeTime
@ -97,43 +99,22 @@ __all__ = (
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks:
return
_log.info('Cleaning up after %d tasks.', len(tasks))
for task in tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.')
class _LoopSentinel:
__slots__ = ()
for task in tasks:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
'message': 'Unhandled exception during Client.run shutdown.',
'exception': task.exception(),
'task': task,
}
)
def __getattr__(self, attr: str) -> None:
msg = (
'loop attribute cannot be accessed in non-async contexts. '
'Consider using either an asynchronous main function and passing it to asyncio.run or '
'using asynchronous initialisation hooks such as Client.setup_hook'
)
raise AttributeError(msg)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try:
_cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
_log.info('Closing the event loop.')
loop.close()
_loop: Any = _LoopSentinel()
class Client:
@ -150,12 +131,6 @@ class Client:
.. versionchanged:: 1.3
Allow disabling the message cache and change the default size to ``1000``.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
The :class:`asyncio.AbstractEventLoop` to use for asynchronous operations.
Defaults to ``None``, in which case the default event loop is used via
:func:`asyncio.get_event_loop()`.
connector: Optional[:class:`aiohttp.BaseConnector`]
The connector to use for connection pooling.
proxy: Optional[:class:`str`]
Proxy URL.
proxy_auth: Optional[:class:`aiohttp.BasicAuth`]
@ -172,10 +147,9 @@ class Client:
.. versionadded:: 1.5
request_guilds: :class:`bool`
Whether to request guilds at startup (behaves similarly to the old
guild_subscriptions option). Defaults to True.
Whether to request guilds at startup. Defaults to True.
.. versionadded:: 1.10
.. versionadded:: 2.0
status: Optional[:class:`.Status`]
A status to start your presence with upon logging on to Discord.
activity: Optional[:class:`.BaseActivity`]
@ -209,36 +183,44 @@ class Client:
Whether to keep presences up-to-date across clients.
The default behavior is ``True`` (what the client does).
.. versionadded:: 2.0
http_trace: :class:`aiohttp.TraceConfig`
The trace configuration to use for tracking HTTP requests the library does using ``aiohttp``.
This allows you to check requests the library is using. For more information, check the
`aiohttp documentation <https://docs.aiohttp.org/en/stable/client_advanced.html#client-tracing>`_.
.. versionadded:: 2.0
Attributes
-----------
ws
The websocket gateway the client is currently connected to. Could be ``None``.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations.
"""
def __init__(
self,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any,
):
# Set in the connect method
def __init__(self, **options: Any) -> None:
self.loop: asyncio.AbstractEventLoop = _loop
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None)
proxy: Optional[str] = options.pop('proxy', None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None)
unsync_clock: bool = options.pop('assume_unsync_clock', True)
http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None)
self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
self.loop,
proxy=proxy,
proxy_auth=proxy_auth,
unsync_clock=unsync_clock,
http_trace=http_trace,
)
self._handlers: Dict[str, Callable] = {'ready': self._handle_ready, 'connect': self._handle_connect}
self._handlers: Dict[str, Callable[..., None]] = {
'ready': self._handle_ready,
'connect': self._handle_connect,
}
self._hooks: Dict[str, Callable] = {
self._hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = {
'before_identify': self._call_before_identify_hook,
}
@ -246,7 +228,7 @@ class Client:
self._sync_presences: bool = options.pop('sync_presence', True)
self._connection: ConnectionState = self._get_state(**options)
self._closed: bool = False
self._ready: asyncio.Event = asyncio.Event()
self._ready: asyncio.Event = MISSING
self._client_status: _ClientStatus = _ClientStatus()
self._client_activities: Dict[Optional[str], Tuple[ActivityTypes, ...]] = {
@ -259,6 +241,19 @@ class Client:
VoiceClient.warn_nacl = False
_log.warning('PyNaCl is not installed, voice will NOT be supported.')
async def __aenter__(self) -> Self:
await self._async_setup_hook()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if not self.is_closed():
await self.close()
# Internals
def _get_state(self, **options: Any) -> ConnectionState:
@ -350,7 +345,7 @@ class Client:
def is_ready(self) -> bool:
""":class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set()
return self._ready is not MISSING and self._ready.is_set()
async def _run_event(
self,
@ -377,9 +372,10 @@ class Client:
**kwargs: Any,
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs)
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}')
# Schedules the task
return self.loop.create_task(wrapped, name=f'discord.py: {event_name}')
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
def dispatch(self, event: str, /, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s.', event)
method = 'on_' + event
@ -419,7 +415,7 @@ class Client:
else:
self._schedule_event(coro, method, *args, **kwargs)
async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None:
async def on_error(self, event_method: str, /, *args: Any, **kwargs: Any) -> None:
"""|coro|
The default error handler provided by the client.
@ -427,6 +423,10 @@ class Client:
By default this prints to :data:`sys.stderr` however it could be
overridden to have a different implementation.
Check :func:`~discord.on_error` for more details.
.. versionchanged:: 2.0
``event_method`` parameter is now positional-only.
"""
print(f'Ignoring exception in {event_method}', file=sys.stderr)
traceback.print_exc()
@ -471,12 +471,45 @@ class Client:
"""
pass
async def _async_setup_hook(self) -> None:
# Called whenever the client needs to initialise asyncio objects with a running loop
loop = asyncio.get_running_loop()
self.loop = loop
self.http.loop = loop
self._connection.loop = loop
await self._connection.async_setup()
self._ready = asyncio.Event()
async def setup_hook(self) -> None:
"""|coro|
A coroutine to be called to setup the bot, by default this is blank.
To perform asynchronous setup after the bot is logged in but before
it has connected to the Websocket, overwrite this coroutine.
This is only called once, in :meth:`login`, and will be called before
any events are dispatched, making it a better solution than doing such
setup in the :func:`~discord.on_ready` event.
.. warning::
Since this is called *before* the websocket connection is made therefore
anything that waits for the websocket will deadlock, this includes things
like :meth:`wait_for` and :meth:`wait_until_ready`.
.. versionadded:: 2.0
"""
pass
# Login state management
async def login(self, token: str) -> None:
"""|coro|
Logs in the client with the specified credentials.
Logs in the client with the specified credentials and
calls the :meth:`setup_hook`.
.. warning::
@ -502,10 +535,13 @@ class Client:
_log.info('Logging in using static token.')
await self._async_setup_hook()
state = self._connection
data = await state.http.static_login(token.strip())
state.analytics_token = data.get('analytics_token', '')
state.user = ClientUser(state=state, data=data)
await self.setup_hook()
async def connect(self, *, reconnect: bool = True) -> None:
"""|coro|
@ -611,7 +647,11 @@ class Client:
await self.ws.close(code=1000)
await self.http.close()
self._ready.clear()
if self._ready is not MISSING:
self._ready.clear()
self.loop = MISSING
def clear(self) -> None:
"""Clears the internal state of the bot.
@ -644,12 +684,9 @@ class Client:
Roughly Equivalent to: ::
try:
loop.run_until_complete(start(*args, **kwargs))
asyncio.run(self.start(*args, **kwargs))
except KeyboardInterrupt:
loop.run_until_complete(close())
# cancel all tasks lingering
finally:
loop.close()
return
.. warning::
@ -657,41 +694,18 @@ class Client:
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
loop = self.loop
try:
loop.add_signal_handler(signal.SIGINT, lambda: loop.stop())
loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop())
except NotImplementedError:
pass
async def runner():
try:
async with self:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
def stop_loop_on_completion(f):
loop.stop()
future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
asyncio.run(runner())
except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.')
finally:
future.remove_done_callback(stop_loop_on_completion)
_log.info('Cleaning up tasks.')
_cleanup_loop(loop)
if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
# nothing to do here
# `asyncio.run` handles the loop cleanup
# and `self.start` closes all sockets and the HTTPClient instance.
return
# Properties
@ -712,7 +726,8 @@ class Client:
The client may be setting multiple activities, these can be accessed under :attr:`initial_activities`.
"""
return create_activity(self._connection._activities[0]) if self._connection._activities else None
state = self._connection
return create_activity(state._activities[0], state) if state._activities else None
@initial_activity.setter
def initial_activity(self, value: Optional[ActivityTypes]) -> None:
@ -727,7 +742,8 @@ class Client:
@property
def initial_activities(self) -> List[ActivityTypes]:
"""List[:class:`.BaseActivity`]: The activities set upon logging in."""
return [create_activity(activity) for activity in self._connection._activities]
state = self._connection
return [create_activity(activity, state) for activity in state._activities]
@initial_activities.setter
def initial_activities(self, values: List[ActivityTypes]) -> None:
@ -750,7 +766,7 @@ class Client:
return
@initial_status.setter
def initial_status(self, value):
def initial_status(self, value: Status):
if value is Status.offline:
self._connection._status = 'invisible'
elif isinstance(value, Status):
@ -837,9 +853,10 @@ class Client:
the user is listening to a song on Spotify with a title longer
than 128 characters. See :issue:`1738` for more information.
"""
activities = tuple(map(create_activity, self._client_activities[None]))
state = self._connection
activities = tuple(create_activity(d, state) for d in self._client_activities[None])
if activities is None and not self.is_closed():
activities = getattr(self._connection.settings, 'custom_activity', [])
activities = getattr(state.settings, 'custom_activity', [])
activities = [activities] if activities else activities
return activities
@ -870,7 +887,8 @@ class Client:
.. versionadded:: 2.0
"""
return tuple(map(create_activity, self._client_activities.get('mobile', [])))
state = self._connection
return tuple(create_activity(d, state) for d in self._client_activities.get('mobile', []))
@property
def desktop_activities(self) -> Tuple[ActivityTypes]:
@ -879,7 +897,8 @@ class Client:
.. versionadded:: 2.0
"""
return tuple(map(create_activity, self._client_activities.get('desktop', [])))
state = self._connection
return tuple(create_activity(d, state) for d in self._client_activities.get('desktop', []))
@property
def web_activities(self) -> Tuple[ActivityTypes]:
@ -888,7 +907,8 @@ class Client:
.. versionadded:: 2.0
"""
return tuple(map(create_activity, self._client_activities.get('web', [])))
state = self._connection
return tuple(create_activity(d, state) for d in self._client_activities.get('web', []))
@property
def client_activities(self) -> Tuple[ActivityTypes]:
@ -897,9 +917,10 @@ class Client:
.. versionadded:: 2.0
"""
activities = tuple(map(create_activity, self._client_activities.get('this', [])))
state = self._connection
activities = tuple(create_activity(d, state) for d in self._client_activities.get('this', []))
if activities is None and not self.is_closed():
activities = getattr(self._connection.settings, 'custom_activity', [])
activities = getattr(state.settings, 'custom_activity', [])
activities = [activities] if activities else activities
return activities
@ -979,7 +1000,7 @@ class Client:
Returns
--------
Optional[:class:`.StageInstance`]
The returns stage instance of ``None`` if not found.
The stage instance or ``None`` if not found.
"""
from .channel import StageChannel
@ -1109,12 +1130,18 @@ class Client:
"""|coro|
Waits until the client's internal cache is all ready.
.. warning::
Calling this inside :meth:`setup_hook` can lead to a deadlock.
"""
await self._ready.wait()
if self._ready is not MISSING:
await self._ready.wait()
def wait_for(
self,
event: str,
/,
*,
check: Optional[Callable[..., bool]] = None,
timeout: Optional[float] = None,
@ -1174,6 +1201,10 @@ class Client:
else:
await channel.send('\N{THUMBS UP SIGN}')
.. versionchanged:: 2.0
``event`` parameter is now positional-only.
Parameters
------------
@ -1220,7 +1251,7 @@ class Client:
# Event registration
def event(self, coro: Coro) -> Coro:
def event(self, coro: Coro, /) -> Coro:
"""A decorator that registers an event to listen to.
You can find more info about the events on the :ref:`documentation below <discord-api-events>`.
@ -1236,6 +1267,10 @@ class Client:
async def on_ready():
print('Ready!')
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Raises
--------
TypeError
@ -1257,7 +1292,7 @@ class Client:
status: Optional[Status] = None,
afk: bool = False,
edit_settings: bool = True,
):
) -> None:
"""|coro|
Changes the client's presence.
@ -1267,8 +1302,8 @@ class Client:
Added option to update settings.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Example
---------
@ -1439,7 +1474,7 @@ class Client:
"""
code = utils.resolve_template(code)
data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore
return Template(data=data, state=self._connection)
async def fetch_guild(self, guild_id: int, /, *, with_counts: bool = True) -> Guild:
"""|coro|
@ -1498,8 +1533,8 @@ class Client:
``name`` and ``icon`` parameters are now keyword-only. The `region`` parameter has been removed.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
----------

13
discord/colour.py

@ -32,7 +32,6 @@ from typing import (
Callable,
Optional,
Tuple,
Type,
Union,
)
@ -90,10 +89,10 @@ class Colour:
def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, Colour) and self.value == other.value
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __str__(self) -> str:
@ -265,28 +264,28 @@ class Colour:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95A5A6)
lighter_gray: Callable[[Type[Self]], Self] = lighter_grey
lighter_gray = lighter_grey
@classmethod
def dark_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607D8B)
dark_gray: Callable[[Type[Self]], Self] = dark_grey
dark_gray = dark_grey
@classmethod
def light_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979C9F)
light_gray: Callable[[Type[Self]], Self] = light_grey
light_gray = light_grey
@classmethod
def darker_grey(cls) -> Self:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546E7A)
darker_gray: Callable[[Type[Self]], Self] = darker_grey
darker_gray = darker_grey
@classmethod
def og_blurple(cls) -> Self:

10
discord/components.py

@ -374,9 +374,9 @@ class SelectOption:
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False,
) -> None:
self.label = label
self.value = label if value is MISSING else value
self.description = description
self.label: str = label
self.value: str = label if value is MISSING else value
self.description: Optional[str] = description
if emoji is not None:
if isinstance(emoji, str):
@ -386,8 +386,8 @@ class SelectOption:
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji = emoji
self.default = default
self.emoji: Optional[Union[str, Emoji, PartialEmoji]] = emoji
self.default: bool = default
def __repr__(self) -> str:
return (

10
discord/context_managers.py

@ -25,13 +25,15 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Optional, Type, TypeVar
if TYPE_CHECKING:
from .abc import Messageable
from types import TracebackType
BE = TypeVar('BE', bound=BaseException)
# fmt: off
__all__ = (
'Typing',
@ -67,13 +69,13 @@ class Typing:
async def __aenter__(self) -> None:
self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id)
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
self.task: asyncio.Task[None] = self.loop.create_task(self.do_typing())
self.task.add_done_callback(_typing_done_callback)
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
self.task.cancel()

190
discord/embeds.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, TypeVar, Union
from typing import Any, Dict, List, Mapping, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
from . import utils
from .colour import Colour
@ -37,20 +37,6 @@ __all__ = (
# fmt: on
class _EmptyEmbed:
def __bool__(self) -> bool:
return False
def __repr__(self) -> str:
return 'Embed.Empty'
def __len__(self) -> int:
return 0
EmptyEmbed: Final = _EmptyEmbed()
class EmbedProxy:
def __init__(self, layer: Dict[str, Any]):
self.__dict__.update(layer)
@ -62,8 +48,8 @@ class EmbedProxy:
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
return f'EmbedProxy({inner})'
def __getattr__(self, attr: str) -> _EmptyEmbed:
return EmptyEmbed
def __getattr__(self, attr: str) -> None:
return None
if TYPE_CHECKING:
@ -72,37 +58,36 @@ if TYPE_CHECKING:
from .types.embed import Embed as EmbedData, EmbedType
T = TypeVar('T')
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol):
text: MaybeEmpty[str]
icon_url: MaybeEmpty[str]
text: Optional[str]
icon_url: Optional[str]
class _EmbedFieldProxy(Protocol):
name: MaybeEmpty[str]
value: MaybeEmpty[str]
name: Optional[str]
value: Optional[str]
inline: bool
class _EmbedMediaProxy(Protocol):
url: MaybeEmpty[str]
proxy_url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
url: Optional[str]
proxy_url: Optional[str]
height: Optional[int]
width: Optional[int]
class _EmbedVideoProxy(Protocol):
url: MaybeEmpty[str]
height: MaybeEmpty[int]
width: MaybeEmpty[int]
url: Optional[str]
height: Optional[int]
width: Optional[int]
class _EmbedProviderProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
name: Optional[str]
url: Optional[str]
class _EmbedAuthorProxy(Protocol):
name: MaybeEmpty[str]
url: MaybeEmpty[str]
icon_url: MaybeEmpty[str]
proxy_icon_url: MaybeEmpty[str]
name: Optional[str]
url: Optional[str]
icon_url: Optional[str]
proxy_icon_url: Optional[str]
class Embed:
@ -121,18 +106,15 @@ class Embed:
.. versionadded:: 2.0
Certain properties return an ``EmbedProxy``, a type
that acts similar to a regular :class:`dict` except using dotted access,
e.g. ``embed.author.icon_url``. If the attribute
is invalid or empty, then a special sentinel value is returned,
:attr:`Embed.Empty`.
For ease of use, all parameters that expect a :class:`str` are implicitly
casted to :class:`str` for you.
.. versionchanged:: 2.0
``Embed.Empty`` has been removed in favour of ``None``.
Attributes
-----------
title: :class:`str`
title: Optional[:class:`str`]
The title of the embed.
This can be set during initialisation.
type: :class:`str`
@ -140,22 +122,19 @@ class Embed:
This can be set during initialisation.
Possible strings for embed types can be found on discord's
`api docs <https://discord.com/developers/docs/resources/channel#embed-object-embed-types>`_
description: :class:`str`
description: Optional[:class:`str`]
The description of the embed.
This can be set during initialisation.
url: :class:`str`
url: Optional[:class:`str`]
The URL of the embed.
This can be set during initialisation.
timestamp: :class:`datetime.datetime`
timestamp: Optional[:class:`datetime.datetime`]
The timestamp of the embed content. This is an aware datetime.
If a naive datetime is passed, it is converted to an aware
datetime with the local timezone.
colour: Union[:class:`Colour`, :class:`int`]
colour: Optional[Union[:class:`Colour`, :class:`int`]]
The colour code of the embed. Aliased to ``color`` as well.
This can be set during initialisation.
Empty
A special sentinel value used by ``EmbedProxy`` and this class
to denote that the value or attribute is empty.
"""
__slots__ = (
@ -174,36 +153,34 @@ class Embed:
'description',
)
Empty: Final = EmptyEmbed
def __init__(
self,
*,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
colour: Optional[Union[int, Colour]] = None,
color: Optional[Union[int, Colour]] = None,
title: Optional[Any] = None,
type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: MaybeEmpty[datetime.datetime] = EmptyEmbed,
url: Optional[Any] = None,
description: Optional[Any] = None,
timestamp: Optional[datetime.datetime] = None,
):
self.colour = colour if colour is not EmptyEmbed else color
self.title = title
self.type = type
self.url = url
self.description = description
self.colour = colour if colour is not None else color
self.title: Optional[str] = title
self.type: EmbedType = type
self.url: Optional[str] = url
self.description: Optional[str] = description
if self.title is not EmptyEmbed:
if self.title is not None:
self.title = str(self.title)
if self.description is not EmptyEmbed:
if self.description is not None:
self.description = str(self.description)
if self.url is not EmptyEmbed:
if self.url is not None:
self.url = str(self.url)
if timestamp is not EmptyEmbed:
if timestamp is not None:
self.timestamp = timestamp
@classmethod
@ -227,18 +204,18 @@ class Embed:
# fill in the basic fields
self.title = data.get('title', EmptyEmbed)
self.type = data.get('type', EmptyEmbed)
self.description = data.get('description', EmptyEmbed)
self.url = data.get('url', EmptyEmbed)
self.title = data.get('title', None)
self.type = data.get('type', None)
self.description = data.get('description', None)
self.url = data.get('url', None)
if self.title is not EmptyEmbed:
if self.title is not None:
self.title = str(self.title)
if self.description is not EmptyEmbed:
if self.description is not None:
self.description = str(self.description)
if self.url is not EmptyEmbed:
if self.url is not None:
self.url = str(self.url)
# try to fill in the more rich fields
@ -268,7 +245,7 @@ class Embed:
return self.__class__.from_dict(self.to_dict())
def __len__(self) -> int:
total = len(self.title) + len(self.description)
total = len(self.title or '') + len(self.description or '')
for field in getattr(self, '_fields', []):
total += len(field['name']) + len(field['value'])
@ -307,34 +284,36 @@ class Embed:
)
@property
def colour(self) -> MaybeEmpty[Colour]:
return getattr(self, '_colour', EmptyEmbed)
def colour(self) -> Optional[Colour]:
return getattr(self, '_colour', None)
@colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]):
if isinstance(value, (Colour, _EmptyEmbed)):
def colour(self, value: Optional[Union[int, Colour]]) -> None:
if value is None:
self._colour = None
elif isinstance(value, Colour):
self._colour = value
elif isinstance(value, int):
self._colour = Colour(value=value)
else:
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.')
raise TypeError(f'Expected discord.Colour, int, or None but received {value.__class__.__name__} instead.')
color = colour
@property
def timestamp(self) -> MaybeEmpty[datetime.datetime]:
return getattr(self, '_timestamp', EmptyEmbed)
def timestamp(self) -> Optional[datetime.datetime]:
return getattr(self, '_timestamp', None)
@timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
def timestamp(self, value: Optional[datetime.datetime]) -> None:
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
value = value.astimezone()
self._timestamp = value
elif isinstance(value, _EmptyEmbed):
self._timestamp = value
elif value is None:
self._timestamp = None
else:
raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead")
raise TypeError(f"Expected datetime.datetime or None received {value.__class__.__name__} instead")
@property
def footer(self) -> _EmbedFooterProxy:
@ -342,12 +321,12 @@ class Embed:
See :meth:`set_footer` for possible values you can access.
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
def set_footer(self, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> Self:
def set_footer(self, *, text: Optional[Any] = None, icon_url: Optional[Any] = None) -> Self:
"""Sets the footer for the embed content.
This function returns the class instance to allow for fluent-style
@ -362,10 +341,10 @@ class Embed:
"""
self._footer = {}
if text is not EmptyEmbed:
if text is not None:
self._footer['text'] = str(text)
if icon_url is not EmptyEmbed:
if icon_url is not None:
self._footer['icon_url'] = str(icon_url)
return self
@ -396,27 +375,24 @@ class Embed:
- ``width``
- ``height``
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
def set_image(self, *, url: MaybeEmpty[Any]) -> Self:
def set_image(self, *, url: Optional[Any]) -> Self:
"""Sets the image for the embed content.
This function returns the class instance to allow for fluent-style
chaining.
.. versionchanged:: 1.4
Passing :attr:`Empty` removes the image.
Parameters
-----------
url: :class:`str`
The source URL for the image. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
if url is None:
try:
del self._image
except AttributeError:
@ -439,19 +415,19 @@ class Embed:
- ``width``
- ``height``
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self, *, url: MaybeEmpty[Any]) -> Self:
def set_thumbnail(self, *, url: Optional[Any]) -> Self:
"""Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style
chaining.
.. versionchanged:: 1.4
Passing :attr:`Empty` removes the thumbnail.
Passing ``None`` removes the thumbnail.
Parameters
-----------
@ -459,7 +435,7 @@ class Embed:
The source URL for the thumbnail. Only HTTP(S) is supported.
"""
if url is EmptyEmbed:
if url is None:
try:
del self._thumbnail
except AttributeError:
@ -481,7 +457,7 @@ class Embed:
- ``height`` for the video height.
- ``width`` for the video width.
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_video', {})) # type: ignore
@ -492,7 +468,7 @@ class Embed:
The only attributes that might be accessed are ``name`` and ``url``.
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore
@ -503,12 +479,12 @@ class Embed:
See :meth:`set_author` for possible values you can access.
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return EmbedProxy(getattr(self, '_author', {})) # type: ignore
def set_author(self, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> Self:
def set_author(self, *, name: Any, url: Optional[Any] = None, icon_url: Optional[Any] = None) -> Self:
"""Sets the author for the embed content.
This function returns the class instance to allow for fluent-style
@ -528,10 +504,10 @@ class Embed:
'name': str(name),
}
if url is not EmptyEmbed:
if url is not None:
self._author['url'] = str(url)
if icon_url is not EmptyEmbed:
if icon_url is not None:
self._author['icon_url'] = str(icon_url)
return self
@ -553,11 +529,11 @@ class Embed:
@property
def fields(self) -> List[_EmbedFieldProxy]:
"""List[Union[``EmbedProxy``, :attr:`Empty`]]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents.
"""List[``EmbedProxy``]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents.
See :meth:`add_field` for possible values you can access.
If the attribute has no value then :attr:`Empty` is returned.
If the attribute has no value then ``None`` is returned.
"""
# Lying to the type checker for better developer UX.
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore

10
discord/emoji.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Iterator, List, Optional, TYPE_CHECKING, Tuple
from typing import Any, Collection, Iterator, List, Optional, TYPE_CHECKING, Tuple
from .asset import Asset, AssetMixin
from .utils import SnowflakeList, snowflake_time, MISSING
@ -142,10 +142,10 @@ class Emoji(_EmojiTag, AssetMixin):
def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -214,7 +214,9 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild_id, self.id, reason=reason)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji:
async def edit(
self, *, name: str = MISSING, roles: Collection[Snowflake] = MISSING, reason: Optional[str] = None
) -> Emoji:
r"""|coro|
Edits the custom emoji.

45
discord/enums.py

@ -25,7 +25,7 @@ from __future__ import annotations
import types
from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Iterator, Mapping
__all__ = (
'Enum',
@ -149,38 +149,38 @@ class EnumMeta(type):
value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood
return actual_cls
def __iter__(cls):
def __iter__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in cls._enum_member_names_)
def __reversed__(cls):
def __reversed__(cls) -> Iterator[Any]:
return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_))
def __len__(cls):
def __len__(cls) -> int:
return len(cls._enum_member_names_)
def __repr__(cls):
def __repr__(cls) -> str:
return f'<enum {cls.__name__}>'
@property
def __members__(cls):
def __members__(cls) -> Mapping[str, Any]:
return types.MappingProxyType(cls._enum_member_map_)
def __call__(cls, value):
def __call__(cls, value: str) -> Any:
try:
return cls._enum_value_map_[value]
except (KeyError, TypeError):
raise ValueError(f"{value!r} is not a valid {cls.__name__}")
def __getitem__(cls, key):
def __getitem__(cls, key: str) -> Any:
return cls._enum_member_map_[key]
def __setattr__(cls, name, value):
def __setattr__(cls, name: str, value: Any) -> None:
raise TypeError('Enums are immutable')
def __delattr__(cls, attr):
def __delattr__(cls, attr: str) -> None:
raise TypeError('Enums are immutable')
def __instancecheck__(self, instance):
def __instancecheck__(self, instance: Any) -> bool:
# isinstance(x, Y)
# -> __instancecheck__(Y, x)
try:
@ -215,7 +215,7 @@ class ChannelType(Enum):
private_thread = 12
stage_voice = 13
def __str__(self):
def __str__(self) -> str:
return self.name
def __int__(self):
@ -258,10 +258,10 @@ class SpeakingState(Enum):
soundshare = 2
priority = 4
def __str__(self):
def __str__(self) -> str:
return self.name
def __int__(self):
def __int__(self) -> int:
return self.value
@ -272,7 +272,7 @@ class VerificationLevel(Enum, comparable=True):
high = 3
highest = 4
def __str__(self):
def __str__(self) -> str:
return self.name
@ -281,7 +281,7 @@ class ContentFilter(Enum, comparable=True):
no_role = 1
all_members = 2
def __str__(self):
def __str__(self) -> str:
return self.name
@ -347,7 +347,7 @@ class Status(Enum):
do_not_disturb = 'dnd'
invisible = 'invisible'
def __str__(self):
def __str__(self) -> str:
return self.value
@ -360,7 +360,7 @@ class DefaultAvatar(Enum):
red = 4
pink = 5
def __str__(self):
def __str__(self) -> str:
return self.name
@ -554,6 +554,7 @@ class UserFlags(Enum):
discord_certified_moderator = 262144
bot_http_interactions = 524288
spammer = 1048576
disable_premium = 2097152
class ActivityType(Enum):
@ -565,7 +566,7 @@ class ActivityType(Enum):
custom = 4
competing = 5
def __int__(self):
def __int__(self) -> int:
return self.value
@ -690,7 +691,7 @@ class VideoQualityMode(Enum):
auto = 1
full = 2
def __int__(self):
def __int__(self) -> int:
return self.value
@ -700,7 +701,7 @@ class ComponentType(Enum):
select = 3
text_input = 4
def __int__(self):
def __int__(self) -> int:
return self.value
@ -719,7 +720,7 @@ class ButtonStyle(Enum):
red = 4
url = 5
def __int__(self):
def __int__(self) -> int:
return self.value

24
discord/ext/commands/_types.py

@ -23,26 +23,42 @@ DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple
T = TypeVar('T')
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from .bot import Bot
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, 'Coro[T]'],
Callable[P, T],
]
else:
P = TypeVar('P')
MaybeCoroFunc = Tuple[P, T]
_Bot = Bot
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
ContextT = TypeVar('ContextT', bound='Context')
Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]]
Error = Union[Callable[["Cog", "ContextT", "CommandError"], Coro[Any]], Callable[["ContextT", "CommandError"], Coro[Any]]]
ContextT = TypeVar('ContextT', bound='Context[Any]')
BotT = TypeVar('BotT', bound=_Bot, covariant=True)
ErrorT = TypeVar('ErrorT', bound='Error[Context[Any]]')
HookT = TypeVar('HookT', bound='Hook[Context[Any]]')
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.

334
discord/ext/commands/bot.py

@ -33,9 +33,24 @@ import importlib.util
import sys
import traceback
import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union, overload
from typing import (
Any,
Callable,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
Iterable,
Collection,
overload,
)
import discord
from discord.utils import MISSING, _is_submodule
from .core import GroupMixin
from .view import StringView
@ -50,36 +65,44 @@ if TYPE_CHECKING:
import importlib.machinery
from discord.message import Message
from discord.abc import User
from discord.abc import User, Snowflake
from ._types import (
_Bot,
BotT,
Check,
CoroFunc,
ContextT,
MaybeCoroFunc,
)
_Prefix = Union[Iterable[str], str]
_PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix]
PrefixType = Union[_Prefix, _PrefixCallable[BotT]]
__all__ = (
'when_mentioned',
'when_mentioned_or',
'Bot',
)
MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
BT = TypeVar('BT', bound='Bot')
def when_mentioned(bot: Bot, msg: Message) -> List[str]:
def when_mentioned(bot: _Bot, msg: Message, /) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
.. versionchanged:: 2.0
``bot`` and ``msg`` parameters are now positional-only.
"""
# bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Bot, Message], List[str]]:
def when_mentioned_or(*prefixes: str) -> Callable[[_Bot, Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -117,34 +140,38 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Bot, Message], List[str]]:
return inner
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class _DefaultRepr:
def __repr__(self):
return '<default-help-command>'
_default = _DefaultRepr()
_default: Any = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options):
class BotBase(GroupMixin[None]):
def __init__(
self,
command_prefix: PrefixType[BotT],
help_command: Optional[HelpCommand[Any]] = _default,
description: Optional[str] = None,
**options: Any,
) -> None:
super().__init__(**options)
self.command_prefix = command_prefix
self.command_prefix: PrefixType[BotT] = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {}
# Self doesn't have the ClientT bound, but since this is a mixin it technically does
self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
self._check_once: List[Check] = []
self._before_invoke: Optional[CoroFunc] = None
self._after_invoke: Optional[CoroFunc] = None
self._help_command: Optional[HelpCommand[Any]] = None
self.description: str = inspect.cleandoc(description) if description else ''
self.owner_id: Optional[int] = options.get('owner_id')
self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set())
self.strip_after_prefix: bool = options.get('strip_after_prefix', False)
if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set')
@ -172,7 +199,7 @@ class BotBase(GroupMixin):
# internal helpers
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
def dispatch(self, event_name: str, /, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name
@ -183,19 +210,19 @@ class BotBase(GroupMixin):
async def close(self) -> None:
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
await self.unload_extension(extension)
except Exception:
pass
for cog in tuple(self.__cogs):
try:
self.remove_cog(cog)
await self.remove_cog(cog)
except Exception:
pass
await super().close() # type: ignore
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
async def on_command_error(self, context: Context[BotT], exception: errors.CommandError, /) -> None:
"""|coro|
The default command error handler provided by the bot.
@ -204,6 +231,10 @@ class BotBase(GroupMixin):
overridden to have a different implementation.
This only fires if you do not specify any listeners for command error.
.. versionchanged:: 2.0
``context`` and ``exception`` parameters are now positional-only.
"""
if self.extra_events.get('on_command_error', None):
return
@ -221,7 +252,7 @@ class BotBase(GroupMixin):
# global check registration
def check(self, func: T) -> T:
def check(self, func: T, /) -> T:
r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied
@ -245,12 +276,15 @@ class BotBase(GroupMixin):
def check_commands(ctx):
return ctx.command.qualified_name in allowed_commands
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
"""
# T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore
return func
def add_check(self, func: Check, /, *, call_once: bool = False) -> None:
def add_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -274,7 +308,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
def remove_check(self, func: Check, /, *, call_once: bool = False) -> None:
def remove_check(self, func: Check[ContextT], /, *, call_once: bool = False) -> None:
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@ -299,7 +333,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def check_once(self, func: CFT) -> CFT:
def check_once(self, func: CFT, /) -> CFT:
r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once
@ -333,11 +367,15 @@ class BotBase(GroupMixin):
def whitelist(ctx):
return ctx.message.author.id in my_whitelist
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
"""
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
async def can_run(self, ctx: Context[BotT], /, *, call_once: bool = False) -> bool:
data = self._check_once if call_once else self._checks
if len(data) == 0:
@ -346,12 +384,15 @@ class BotBase(GroupMixin):
# type-checker doesn't distinguish between functions and methods
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user: User) -> bool:
async def is_owner(self, user: User, /) -> bool:
"""|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
this bot.
.. versionchanged:: 2.0
``user`` parameter is now positional-only.
Parameters
-----------
user: :class:`.abc.User`
@ -374,7 +415,7 @@ class BotBase(GroupMixin):
else:
raise AttributeError('Owners aren\'t set.')
def before_invoke(self, coro: CFT) -> CFT:
def before_invoke(self, coro: CFT, /) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@ -390,6 +431,10 @@ class BotBase(GroupMixin):
without error. If any check or argument parsing procedures fail
then the hooks are not called.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters
-----------
coro: :ref:`coroutine <coroutine>`
@ -406,7 +451,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT:
def after_invoke(self, coro: CFT, /) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@ -423,6 +468,10 @@ class BotBase(GroupMixin):
callback raising an error (i.e. :exc:`.CommandInvokeError`\).
This makes it ideal for clean-up scenarios.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters
-----------
coro: :ref:`coroutine <coroutine>`
@ -441,9 +490,13 @@ class BotBase(GroupMixin):
# listener registration
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
def add_listener(self, func: CoroFunc, /, name: str = MISSING) -> None:
"""The non decorator alternative to :meth:`.listen`.
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters
-----------
func: :ref:`coroutine <coroutine>`
@ -473,9 +526,13 @@ class BotBase(GroupMixin):
else:
self.extra_events[name] = [func]
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
def remove_listener(self, func: CoroFunc, /, name: str = MISSING) -> None:
"""Removes a listener from the pool of listeners.
.. versionchanged:: 2.0
``func`` parameter is now positional-only.
Parameters
-----------
func
@ -531,11 +588,29 @@ class BotBase(GroupMixin):
# cogs
def add_cog(self, cog: Cog, /, *, override: bool = False) -> None:
"""Adds a "cog" to the bot.
async def add_cog(
self,
cog: Cog,
/,
*,
override: bool = False,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> None:
"""|coro|
Adds a "cog" to the bot.
A cog is a class that has its own event listeners and commands.
If the cog is a :class:`.app_commands.Group` then it is added to
the bot's :class:`~discord.app_commands.CommandTree` as well.
.. note::
Exceptions raised inside a :class:`.Cog`'s :meth:`~.Cog.cog_load` method will be
propagated to the caller.
.. versionchanged:: 2.0
:exc:`.ClientException` is raised when a cog with the same name
@ -545,6 +620,10 @@ class BotBase(GroupMixin):
``cog`` parameter is now positional-only.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
-----------
cog: :class:`.Cog`
@ -553,6 +632,19 @@ class BotBase(GroupMixin):
If a previously loaded cog with the same name should be ejected
instead of raising an error.
.. versionadded:: 2.0
guild: Optional[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guild where the cog group would be added to. If not given then
it becomes a global command instead.
.. versionadded:: 2.0
guilds: List[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guilds where the cog group would be added to. If not given then
it becomes a global command instead. Cannot be mixed with
``guild``.
.. versionadded:: 2.0
Raises
@ -574,9 +666,12 @@ class BotBase(GroupMixin):
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
self.remove_cog(cog_name)
await self.remove_cog(cog_name, guild=guild, guilds=guilds)
cog = cog._inject(self)
if isinstance(cog, app_commands.Group):
self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds)
cog = await cog._inject(self, override=override, guild=guild, guilds=guilds)
self.__cogs[cog_name] = cog
def get_cog(self, name: str, /) -> Optional[Cog]:
@ -602,8 +697,17 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
def remove_cog(self, name: str, /) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
async def remove_cog(
self,
name: str,
/,
*,
guild: Optional[Snowflake] = MISSING,
guilds: List[Snowflake] = MISSING,
) -> Optional[Cog]:
"""|coro|
Removes a cog from the bot and returns it.
All registered commands and event listeners that the
cog has registered will be removed as well.
@ -614,10 +718,27 @@ class BotBase(GroupMixin):
``name`` parameter is now positional-only.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
-----------
name: :class:`str`
The name of the cog to remove.
guild: Optional[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guild where the cog group would be removed from. If not given then
a global command is removed instead instead.
.. versionadded:: 2.0
guilds: List[:class:`~discord.abc.Snowflake`]
If the cog is an application command group, then this would be the
guilds where the cog group would be removed from. If not given then
a global command is removed instead instead. Cannot be mixed with
``guild``.
.. versionadded:: 2.0
Returns
-------
@ -632,7 +753,16 @@ class BotBase(GroupMixin):
help_command = self._help_command
if help_command and help_command.cog is cog:
help_command.cog = None
cog._eject(self)
guild_ids = _retrieve_guild_ids(cog, guild, guilds)
if isinstance(cog, app_commands.Group):
if guild_ids is None:
self.__tree.remove_command(name)
else:
for guild_id in guild_ids:
self.__tree.remove_command(name, guild=discord.Object(guild_id))
await cog._eject(self, guild_ids=guild_ids)
return cog
@ -643,12 +773,12 @@ class BotBase(GroupMixin):
# extensions
def _remove_module_references(self, name: str) -> None:
async def _remove_module_references(self, name: str) -> None:
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
if _is_submodule(name, cog.__module__):
self.remove_cog(cogname)
await self.remove_cog(cogname)
# remove all the commands from the module
for cmd in self.all_commands.copy().values():
@ -667,14 +797,17 @@ class BotBase(GroupMixin):
for index in reversed(remove):
del event_list[index]
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
# remove all relevant application commands from the tree
self.__tree._remove_with_module(name)
async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try:
func = getattr(lib, 'teardown')
except AttributeError:
pass
else:
try:
func(self)
await func(self)
except Exception:
pass
finally:
@ -685,7 +818,7 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
async def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
@ -702,11 +835,11 @@ class BotBase(GroupMixin):
raise errors.NoEntryPointError(key)
try:
setup(self)
await setup(self)
except Exception as e:
del sys.modules[key]
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, key)
await self._remove_module_references(lib.__name__)
await self._call_module_finalizers(lib, key)
raise errors.ExtensionFailed(key, e) from e
else:
self.__extensions[key] = lib
@ -717,8 +850,10 @@ class BotBase(GroupMixin):
except ImportError:
raise errors.ExtensionNotFound(name)
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Loads an extension.
async def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""|coro|
Loads an extension.
An extension is a python module that contains commands, cogs, or
listeners.
@ -727,6 +862,10 @@ class BotBase(GroupMixin):
the entry point on what to do when the extension is loaded. This entry
point must have a single argument, the ``bot``.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
------------
name: :class:`str`
@ -762,10 +901,12 @@ class BotBase(GroupMixin):
if spec is None:
raise errors.ExtensionNotFound(name)
self._load_from_module_spec(spec, name)
await self._load_from_module_spec(spec, name)
async def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""|coro|
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Unloads an extension.
Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
removed from the bot and the module is un-imported.
@ -775,6 +916,10 @@ class BotBase(GroupMixin):
parameter, the ``bot``, similar to ``setup`` from
:meth:`~.Bot.load_extension`.
.. versionchanged:: 2.0
This method is now a :term:`coroutine`.
Parameters
------------
name: :class:`str`
@ -802,10 +947,10 @@ class BotBase(GroupMixin):
if lib is None:
raise errors.ExtensionNotLoaded(name)
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
await self._remove_module_references(lib.__name__)
await self._call_module_finalizers(lib, name)
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
async def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
@ -856,14 +1001,14 @@ class BotBase(GroupMixin):
try:
# Unload and then load the module...
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
self.load_extension(name)
await self._remove_module_references(lib.__name__)
await self._call_module_finalizers(lib, name)
await self.load_extension(name)
except Exception:
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
lib.setup(self) # type: ignore
await lib.setup(self)
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@ -878,11 +1023,11 @@ class BotBase(GroupMixin):
# help command stuff
@property
def help_command(self) -> Optional[HelpCommand]:
def help_command(self) -> Optional[HelpCommand[Any]]:
return self._help_command
@help_command.setter
def help_command(self, value: Optional[HelpCommand]) -> None:
def help_command(self, value: Optional[HelpCommand[Any]]) -> None:
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand')
@ -896,14 +1041,32 @@ class BotBase(GroupMixin):
else:
self._help_command = None
# application command interop
# As mentioned above, this is a mixin so the Self type hint fails here.
# However, since the only classes that can use this are subclasses of Client
# anyway, then this is sound.
@property
def tree(self) -> app_commands.CommandTree[Self]: # type: ignore
""":class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands
in this bot.
.. versionadded:: 2.0
"""
return self.__tree
# command processing
async def get_prefix(self, message: Message) -> Union[List[str], str]:
async def get_prefix(self, message: Message, /) -> Union[List[str], str]:
"""|coro|
Retrieves the prefix the bot is listening to
with the message as a context.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
Parameters
-----------
message: :class:`discord.Message`
@ -917,11 +1080,12 @@ class BotBase(GroupMixin):
"""
prefix = ret = self.command_prefix
if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message)
# self will be a Bot or AutoShardedBot
ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore
if not isinstance(ret, str):
try:
ret = list(ret)
ret = list(ret) # type: ignore
except TypeError:
# It's possible that a generator raised this exception. Don't
# replace it with our own error if that's the case.
@ -942,6 +1106,7 @@ class BotBase(GroupMixin):
async def get_context(
self,
message: Message,
/,
) -> Context[Self]: # type: ignore
...
@ -949,16 +1114,18 @@ class BotBase(GroupMixin):
async def get_context(
self,
message: Message,
/,
*,
cls: Type[CXT] = ...,
) -> CXT: # type: ignore
cls: Type[ContextT] = ...,
) -> ContextT:
...
async def get_context(
self,
message: Message,
/,
*,
cls: Type[CXT] = MISSING,
cls: Type[ContextT] = MISSING,
) -> Any:
r"""|coro|
@ -972,6 +1139,10 @@ class BotBase(GroupMixin):
If the context is not valid then it is not a valid candidate to be
invoked under :meth:`~.Bot.invoke`.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
Parameters
-----------
message: :class:`discord.Message`
@ -1039,12 +1210,16 @@ class BotBase(GroupMixin):
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT], /) -> None:
"""|coro|
Invokes the command given under the invocation context and
handles all the internal event dispatch mechanisms.
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters
-----------
ctx: :class:`.Context`
@ -1065,7 +1240,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc)
async def process_commands(self, message: Message) -> None:
async def process_commands(self, message: Message, /) -> None:
"""|coro|
This function processes the commands that have been registered
@ -1082,6 +1257,10 @@ class BotBase(GroupMixin):
This also checks if the message's author is a bot and doesn't
call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so.
.. versionchanged:: 2.0
``message`` parameter is now positional-only.
Parameters
-----------
message: :class:`discord.Message`
@ -1091,9 +1270,10 @@ class BotBase(GroupMixin):
return
ctx = await self.get_context(message)
await self.invoke(ctx)
# the type of the invocation context's bot attribute will be correct
await self.invoke(ctx) # type: ignore
async def on_message(self, message):
async def on_message(self, message: Message, /) -> None:
await self.process_commands(message)

134
discord/ext/commands/cog.py

@ -24,14 +24,17 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import inspect
import discord.utils
import discord
from discord import app_commands
from discord.utils import maybe_coroutine
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union
from ._types import _BaseCommand
from ._types import _BaseCommand, BotT
if TYPE_CHECKING:
from typing_extensions import Self
from discord.abc import Snowflake
from .bot import BotBase
from .context import Context
@ -109,20 +112,35 @@ class CogMeta(type):
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_commands__: List[Command[Any, ..., Any]]
__cog_is_app_commands_group__: bool
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_name__'] = kwargs.get('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
is_parent = any(issubclass(base, app_commands.Group) for base in bases)
attrs['__cog_is_app_commands_group__'] = is_parent
description = kwargs.pop('description', None)
description = kwargs.get('description', None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
if is_parent:
attrs['__discord_app_commands_skip_init_binding__'] = True
# This is hacky, but it signals the Group not to process this info.
# It's overridden later.
attrs['__discord_app_commands_group_children__'] = True
else:
# Remove the extraneous keyword arguments we're using
kwargs.pop('name', None)
kwargs.pop('description', None)
commands = {}
cog_app_commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
@ -143,6 +161,8 @@ class CogMeta(type):
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None:
cog_app_commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
@ -154,6 +174,13 @@ class CogMeta(type):
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_app_commands__ = list(cog_app_commands.values())
if is_parent:
# Prefill the app commands for the Group as well..
# The type checker doesn't like runtime attribute modification and this one's
# optional so it can't be cheesed.
new_cls.__discord_app_commands_group_children__ = new_cls.__cog_app_commands__ # type: ignore
listeners_as_list = []
for listener in listeners.values():
@ -189,10 +216,11 @@ class Cog(metaclass=CogMeta):
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command[Self, ..., Any]]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command[Self, ..., Any]]
__cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
# For issue 426, we need to store a copy of the command objects
@ -219,6 +247,27 @@ class Cog(metaclass=CogMeta):
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
# Register the application commands
children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = []
for command in cls.__cog_app_commands__:
if cls.__cog_is_app_commands_group__:
# Type checker doesn't understand this type of narrowing.
# Not even with TypeGuard somehow.
command.parent = self # type: ignore
copy = command._copy_with_binding(self)
children.append(copy)
if command._attr:
setattr(self, command._attr, copy)
self.__cog_app_commands__ = children
if cls.__cog_is_app_commands_group__:
# Dynamic attribute setting
self.__discord_app_commands_group_children__ = children # type: ignore
# Enforce this to work even if someone forgets __init__
self.module = cls.__module__ # type: ignore
return self
def get_commands(self) -> List[Command[Self, ..., Any]]:
@ -330,18 +379,35 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method
def cog_unload(self) -> None:
"""A special method that is called when the cog gets removed.
async def cog_load(self) -> None:
"""|maybecoro|
A special method that is called when the cog gets loaded.
Subclasses must replace this if they want special asynchronous loading behaviour.
Note that the ``__init__`` special method does not allow asynchronous code to run
inside it, thus this is helpful for setting up code that needs to be asynchronous.
.. versionadded:: 2.0
"""
pass
@_cog_special_method
async def cog_unload(self) -> None:
"""|maybecoro|
This function **cannot** be a coroutine. It must be a regular
function.
A special method that is called when the cog gets removed.
Subclasses must replace this if they want special unloading behaviour.
.. versionchanged:: 2.0
This method can now be a :term:`coroutine`.
"""
pass
@_cog_special_method
def bot_check_once(self, ctx: Context) -> bool:
def bot_check_once(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check_once`
check.
@ -351,7 +417,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def bot_check(self, ctx: Context) -> bool:
def bot_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :meth:`.Bot.check`
check.
@ -361,7 +427,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def cog_check(self, ctx: Context) -> bool:
def cog_check(self, ctx: Context[BotT]) -> bool:
"""A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog.
@ -371,7 +437,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
async def cog_command_error(self, ctx: Context[BotT], error: Exception) -> None:
"""A special method that is called whenever an error
is dispatched inside this cog.
@ -390,7 +456,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None:
async def cog_before_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
@ -405,7 +471,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None:
async def cog_after_invoke(self, ctx: Context[BotT]) -> None:
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.
@ -419,9 +485,13 @@ class Cog(metaclass=CogMeta):
"""
pass
def _inject(self, bot: BotBase) -> Self:
async def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self:
cls = self.__class__
# we'll call this first so that errors can propagate without
# having to worry about undoing anything
await maybe_coroutine(self.cog_load)
# realistically, the only thing that can cause loading errors
# is essentially just the command loading, which raises if there are
# duplicates. When this condition is met, we want to undo all what
@ -430,7 +500,8 @@ class Cog(metaclass=CogMeta):
command.cog = self
if command.parent is None:
try:
bot.add_command(command)
# Type checker does not understand the generic bounds here
bot.add_command(command) # type: ignore
except Exception as e:
# undo our additions
for to_undo in self.__cog_commands__[:index]:
@ -452,9 +523,15 @@ class Cog(metaclass=CogMeta):
for name, method_name in self.__cog_listeners__:
bot.add_listener(getattr(self, method_name), name)
# Only do this if these are "top level" commands
if not cls.__cog_is_app_commands_group__:
for command in self.__cog_app_commands__:
# This is already atomic
bot.tree.add_command(command, override=override, guild=guild, guilds=guilds)
return self
def _eject(self, bot: BotBase) -> None:
async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None:
cls = self.__class__
try:
@ -462,6 +539,15 @@ class Cog(metaclass=CogMeta):
if command.parent is None:
bot.remove_command(command.name)
if not cls.__cog_is_app_commands_group__:
for command in self.__cog_app_commands__:
guild_ids = guild_ids or command._guild_ids
if guild_ids is None:
bot.tree.remove_command(command.name)
else:
for guild_id in guild_ids:
bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id))
for name, method_name in self.__cog_listeners__:
bot.remove_listener(getattr(self, method_name), name)
@ -472,6 +558,6 @@ class Cog(metaclass=CogMeta):
bot.remove_check(self.bot_check_once, call_once=True)
finally:
try:
self.cog_unload()
await maybe_coroutine(self.cog_unload)
except Exception:
pass

16
discord/ext/commands/context.py

@ -28,6 +28,8 @@ import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
from ._types import BotT
import discord.abc
import discord.utils
@ -58,7 +60,6 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Bot")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
@ -132,10 +133,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None,
command: Optional[Command] = None,
command: Optional[Command[Any, ..., Any]] = None,
invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None,
invoked_subcommand: Optional[Command[Any, ..., Any]] = None,
subcommand_passed: Optional[str] = None,
command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None,
@ -145,11 +146,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.command: Optional[Command[Any, ..., Any]] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.invoked_subcommand: Optional[Command[Any, ..., Any]] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
@ -352,6 +353,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
"""
from .core import Group, Command, wrap_callback
from .errors import CommandError
from .help import _context
bot = self.bot
cmd = bot.help_command
@ -359,8 +361,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if cmd is None:
return None
cmd = cmd.copy()
cmd.context = self
_context.set(self)
if len(args) == 0:
await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping()

153
discord/ext/commands/converter.py

@ -41,7 +41,6 @@ from typing import (
Tuple,
Union,
runtime_checkable,
overload,
)
import discord
@ -51,9 +50,8 @@ if TYPE_CHECKING:
from .context import Context
from discord.state import Channel
from discord.threads import Thread
from .bot import Bot
_Bot = Bot
from ._types import BotT, _Bot
__all__ = (
@ -80,13 +78,14 @@ __all__ = (
'ThreadConverter',
'GuildChannelConverter',
'GuildStickerConverter',
'ScheduledEventConverter',
'clean_content',
'Greedy',
'run_converters',
)
def _get_from_guilds(bot, getter, argument):
def _get_from_guilds(bot: _Bot, getter: str, argument: Any) -> Any:
result = None
for guild in bot.guilds:
result = getattr(guild, getter)(argument)
@ -114,7 +113,7 @@ class Converter(Protocol[T_co]):
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
async def convert(self, ctx: Context, argument: str) -> T_co:
async def convert(self, ctx: Context[BotT], argument: str) -> T_co:
"""|coro|
The method to override to do conversion logic.
@ -162,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
2. Lookup by member, role, or channel mention.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Object:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None:
@ -195,7 +194,7 @@ class MemberConverter(IDConverter[discord.Member]):
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
"""
async def query_member_named(self, guild, argument):
async def query_member_named(self, guild: discord.Guild, argument: str) -> Optional[discord.Member]:
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
@ -226,7 +225,7 @@ class MemberConverter(IDConverter[discord.Member]):
return None
return members[0]
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Member:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild
@ -280,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
and it's not available in cache.
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.User:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None
state = ctx._state
@ -345,7 +344,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
if not match:
raise MessageNotFound(argument)
data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
channel_id = discord.utils._get_as_snowflake(data, 'channel_id') or ctx.channel.id
message_id = int(data['message_id'])
guild_id = data.get('guild_id')
if guild_id is None:
@ -358,7 +357,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(
ctx: Context[_Bot], guild_id: Optional[int], channel_id: Optional[int]
ctx: Context[BotT], guild_id: Optional[int], channel_id: Optional[int]
) -> Optional[Union[Channel, Thread]]:
if channel_id is None:
# we were passed just a message id so we can assume the channel is the current context channel
@ -372,7 +371,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
return ctx.bot.get_channel(channel_id)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialMessage:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
@ -395,14 +394,14 @@ class MessageConverter(IDConverter[discord.Message]):
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Message:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id)
if message:
return message
channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id)
if not channel or not isinstance(channel, discord.abc.Messageable):
raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here
raise ChannelNotFound(channel_id)
try:
return await channel.fetch_message(message_id)
except discord.NotFound:
@ -426,11 +425,11 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.abc.GuildChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(ctx: Context[BotT], argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -447,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def check(c):
return isinstance(c, type) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
result = discord.utils.find(check, bot.get_all_channels()) # type: ignore
else:
channel_id = int(match.group(1))
if guild:
@ -462,7 +461,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
@ -501,7 +500,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.TextChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel)
@ -521,7 +520,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.VoiceChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel)
@ -540,7 +539,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
3. Lookup by name
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StageChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel)
@ -560,7 +559,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.CategoryChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel)
@ -579,7 +578,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.StoreChannel:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel)
@ -597,7 +596,7 @@ class ThreadConverter(IDConverter[discord.Thread]):
.. versionadded: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Thread:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread)
@ -629,7 +628,7 @@ class ColourConverter(Converter[discord.Colour]):
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument):
def parse_hex_number(self, argument: str) -> discord.Colour:
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
@ -640,7 +639,7 @@ class ColourConverter(Converter[discord.Colour]):
else:
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
def parse_rgb_number(self, argument: str, number: str) -> int:
if number[-1] == '%':
value = int(number[:-1])
if not (0 <= value <= 100):
@ -652,7 +651,7 @@ class ColourConverter(Converter[discord.Colour]):
raise BadColourArgument(argument)
return value
def parse_rgb(self, argument, *, regex=RGB_REGEX):
def parse_rgb(self, argument: str, *, regex: re.Pattern[str] = RGB_REGEX) -> discord.Colour:
match = regex.match(argument)
if match is None:
raise BadColourArgument(argument)
@ -662,7 +661,7 @@ class ColourConverter(Converter[discord.Colour]):
blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Colour:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Colour:
if argument[0] == '#':
return self.parse_hex_number(argument[1:])
@ -703,7 +702,7 @@ class RoleConverter(IDConverter[discord.Role]):
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Role:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Role:
guild = ctx.guild
if not guild:
raise NoPrivateMessage()
@ -722,7 +721,7 @@ class RoleConverter(IDConverter[discord.Role]):
class GameConverter(Converter[discord.Game]):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Game:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Game:
return discord.Game(name=argument)
@ -735,7 +734,7 @@ class InviteConverter(Converter[discord.Invite]):
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Invite:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Invite:
try:
invite = await ctx.bot.fetch_invite(argument)
return invite
@ -754,7 +753,7 @@ class GuildConverter(IDConverter[discord.Guild]):
.. versionadded:: 1.7
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Guild:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Guild:
match = self._get_id_match(argument)
result = None
@ -786,7 +785,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.Emoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None
bot = ctx.bot
@ -820,7 +819,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.PartialEmoji:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match:
@ -844,12 +843,12 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
The lookup strategy is as follows (in order):
1. Lookup by ID.
3. Lookup by name
2. Lookup by name.
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[_Bot], argument: str) -> discord.GuildSticker:
async def convert(self, ctx: Context[BotT], argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument)
result = None
bot = ctx.bot
@ -874,6 +873,65 @@ class GuildStickerConverter(IDConverter[discord.GuildSticker]):
return result
class ScheduledEventConverter(IDConverter[discord.ScheduledEvent]):
"""Converts to a :class:`~discord.ScheduledEvent`.
Lookups are done for the local guild if available. Otherwise, for a DM context,
lookup is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by url.
3. Lookup by name.
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context[BotT], argument: str) -> discord.ScheduledEvent:
guild = ctx.guild
match = self._get_id_match(argument)
result = None
if match:
# ID match
event_id = int(match.group(1))
if guild:
result = guild.get_scheduled_event(event_id)
else:
for guild in ctx.bot.guilds:
result = guild.get_scheduled_event(event_id)
if result:
break
else:
pattern = (
r'https?://(?:(ptb|canary|www)\.)?discord\.com/events/'
r'(?P<guild_id>[0-9]{15,20})/'
r'(?P<event_id>[0-9]{15,20})$'
)
match = re.match(pattern, argument, flags=re.I)
if match:
# URL match
guild = ctx.bot.get_guild(int(match.group('guild_id')))
if guild:
event_id = int(match.group('event_id'))
result = guild.get_scheduled_event(event_id)
else:
# lookup by name
if guild:
result = discord.utils.get(guild.scheduled_events, name=argument)
else:
for guild in ctx.bot.guilds:
result = discord.utils.get(guild.scheduled_events, name=argument)
if result:
break
if result is None:
raise ScheduledEventNotFound(argument)
return result
class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of
said content.
@ -907,7 +965,7 @@ class clean_content(Converter[str]):
self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown
async def convert(self, ctx: Context[_Bot], argument: str) -> str:
async def convert(self, ctx: Context[BotT], argument: str) -> str:
msg = ctx.message
if ctx.guild:
@ -924,7 +982,7 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user'
return f'@{m.display_name}' if m else '@deleted-user'
def resolve_role(id: int) -> str:
return '@deleted-role'
@ -932,7 +990,7 @@ class clean_content(Converter[str]):
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id) # type: ignore
c = ctx.guild._resolve_channel(id) # type: ignore
return f'#{c.name}' if c else '#deleted-channel'
else:
@ -987,10 +1045,10 @@ class Greedy(List[T]):
__slots__ = ('converter',)
def __init__(self, *, converter: T):
self.converter = converter
def __init__(self, *, converter: T) -> None:
self.converter: T = converter
def __repr__(self):
def __repr__(self) -> str:
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
@ -1039,11 +1097,11 @@ def get_converter(param: inspect.Parameter) -> Any:
_GenericAlias = type(List[T])
def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias) # type: ignore
def is_generic_type(tp: Any, *, _GenericAlias: type = _GenericAlias) -> bool:
return isinstance(tp, type) and issubclass(tp, Generic) or isinstance(tp, _GenericAlias)
CONVERTER_MAPPING: Dict[Type[Any], Any] = {
CONVERTER_MAPPING: Dict[type, Any] = {
discord.Object: ObjectConverter,
discord.Member: MemberConverter,
discord.User: UserConverter,
@ -1064,10 +1122,11 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
discord.Thread: ThreadConverter,
discord.abc.GuildChannel: GuildChannelConverter,
discord.GuildSticker: GuildStickerConverter,
discord.ScheduledEvent: ScheduledEventConverter,
}
async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def _actual_conversion(ctx: Context[BotT], converter, argument: str, param: inspect.Parameter):
if converter is bool:
return _convert_to_bool(argument)
@ -1105,7 +1164,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter):
async def run_converters(ctx: Context[BotT], converter: Any, argument: str, param: inspect.Parameter) -> Any:
"""|coro|
Runs converters for a given converter, argument, and parameter.

4
discord/ext/commands/cooldowns.py

@ -220,7 +220,7 @@ class CooldownMapping:
return self._type
@classmethod
def from_cooldown(cls, rate, per, type) -> Self:
def from_cooldown(cls, rate: float, per: float, type: Callable[[Message], Any]) -> Self:
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:
@ -297,7 +297,7 @@ class _Semaphore:
def __init__(self, number: int) -> None:
self.value: int = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str:

195
discord/ext/commands/core.py

@ -61,11 +61,15 @@ if TYPE_CHECKING:
from discord.message import Message
from ._types import (
BotT,
ContextT,
Coro,
CoroFunc,
Check,
Hook,
Error,
ErrorT,
HookT,
)
@ -101,10 +105,8 @@ MISSING: Any = discord.utils.MISSING
T = TypeVar('T')
CogT = TypeVar('CogT', bound='Optional[Cog]')
CommandT = TypeVar('CommandT', bound='Command')
ContextT = TypeVar('ContextT', bound='Context')
# CHT = TypeVar('CHT', bound='Check')
GroupT = TypeVar('GroupT', bound='Group')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
if TYPE_CHECKING:
P = ParamSpec('P')
@ -112,7 +114,7 @@ else:
P = TypeVar('P')
def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
def unwrap_function(function: Callable[..., Any], /) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, '__wrapped__'):
@ -126,6 +128,7 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
def get_signature_parameters(
function: Callable[..., Any],
globalns: Dict[str, Any],
/,
*,
skip_parameters: Optional[int] = None,
) -> Dict[str, inspect.Parameter]:
@ -159,9 +162,9 @@ def get_signature_parameters(
return params
def wrap_callback(coro):
def wrap_callback(coro: Callable[P, Coro[T]], /) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try:
ret = await coro(*args, **kwargs)
except CommandError:
@ -175,9 +178,11 @@ def wrap_callback(coro):
return wrapped
def hooked_wrapped_callback(command, ctx, coro):
def hooked_wrapped_callback(
command: Command[Any, ..., Any], ctx: Context[BotT], coro: Callable[P, Coro[T]], /
) -> Callable[P, Coro[Optional[T]]]:
@functools.wraps(coro)
async def wrapped(*args, **kwargs):
async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
try:
ret = await coro(*args, **kwargs)
except CommandError:
@ -191,7 +196,7 @@ def hooked_wrapped_callback(command, ctx, coro):
raise CommandInvokeError(exc) from exc
finally:
if command._max_concurrency is not None:
await command._max_concurrency.release(ctx)
await command._max_concurrency.release(ctx.message)
await command.call_after_hooks(ctx)
return ret
@ -318,6 +323,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
Callable[Concatenate[CogT, ContextT, P], Coro[T]],
Callable[Concatenate[ContextT, P], Coro[T]],
],
/,
**kwargs: Any,
) -> None:
if not asyncio.iscoroutinefunction(func):
@ -359,7 +365,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError:
checks = kwargs.get('checks', [])
self.checks: List[Check] = checks
self.checks: List[Check[ContextT]] = checks
try:
cooldown = func.__commands_cooldown__
@ -387,8 +393,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.cog: CogT = None
# bandaid for the fact that sometimes parent can be the bot instance
parent = kwargs.get('parent')
self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore
parent: Optional[GroupMixin[Any]] = kwargs.get('parent')
self.parent: Optional[GroupMixin[Any]] = parent if isinstance(parent, _BaseCommand) else None
self._before_invoke: Optional[Hook] = None
try:
@ -422,16 +428,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__
self.module: str = unwrap.__module__
try:
globalns = unwrap.__globals__
except AttributeError:
globalns = {}
self.params = get_signature_parameters(function, globalns)
self.params: Dict[str, inspect.Parameter] = get_signature_parameters(function, globalns)
def add_check(self, func: Check, /) -> None:
def add_check(self, func: Check[ContextT], /) -> None:
"""Adds a check to the command.
This is the non-decorator interface to :func:`.check`.
@ -450,7 +456,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.checks.append(func)
def remove_check(self, func: Check, /) -> None:
def remove_check(self, func: Check[ContextT], /) -> None:
"""Removes a check from the command.
This function is idempotent and will not raise an exception
@ -476,7 +482,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def update(self, **kwargs: Any) -> None:
"""Updates :class:`Command` instance with updated attribute.
This works similarly to the :func:`.command` decorator in terms
This works similarly to the :func:`~discord.ext.commands.command` decorator in terms
of parameters in that they are passed to the :class:`Command` or
subclass constructors, sans the name and callback.
"""
@ -484,7 +490,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self.__init__(self.callback, **dict(self.__original_kwargs__, **kwargs))
self.cog = cog
async def __call__(self, context: Context, *args: P.args, **kwargs: P.kwargs) -> T:
async def __call__(self, context: Context[BotT], /, *args: P.args, **kwargs: P.kwargs) -> T:
"""|coro|
Calls the internal callback that the command holds.
@ -496,6 +502,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
the proper arguments and types to this function.
.. versionadded:: 1.3
.. versionchanged:: 2.0
``context`` parameter is now positional-only.
"""
if self.cog is not None:
return await self.callback(self.cog, context, *args, **kwargs) # type: ignore
@ -539,7 +549,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
return self.copy()
async def dispatch_error(self, ctx: Context, error: Exception) -> None:
async def dispatch_error(self, ctx: Context[BotT], error: CommandError, /) -> None:
ctx.command_failed = True
cog = self.cog
try:
@ -549,7 +559,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
else:
injected = wrap_callback(coro)
if cog is not None:
await injected(cog, ctx, error)
await injected(cog, ctx, error) # type: ignore
else:
await injected(ctx, error)
@ -562,7 +572,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.bot.dispatch('command_error', ctx, error)
async def transform(self, ctx: Context, param: inspect.Parameter) -> Any:
async def transform(self, ctx: Context[BotT], param: inspect.Parameter, /) -> Any:
required = param.default is param.empty
converter = get_converter(param)
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
@ -610,7 +620,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
# type-checker fails to narrow argument
return await run_converters(ctx, converter, argument, param) # type: ignore
async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any:
async def _transform_greedy_pos(
self, ctx: Context[BotT], param: inspect.Parameter, required: bool, converter: Any
) -> Any:
view = ctx.view
result = []
while not view.eof:
@ -631,7 +643,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return param.default
return result
async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any:
async def _transform_greedy_var_pos(self, ctx: Context[BotT], param: inspect.Parameter, converter: Any) -> Any:
view = ctx.view
previous = view.index
try:
@ -669,7 +681,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(reversed(entries))
@property
def parents(self) -> List[Group]:
def parents(self) -> List[Group[Any, ..., Any]]:
"""List[:class:`Group`]: Retrieves the parents of this command.
If the command has no parents then it returns an empty :class:`list`.
@ -687,7 +699,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return entries
@property
def root_parent(self) -> Optional[Group]:
def root_parent(self) -> Optional[Group[Any, ..., Any]]:
"""Optional[:class:`Group`]: Retrieves the root parent of this command.
If the command has no parents then it returns ``None``.
@ -716,7 +728,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
def __str__(self) -> str:
return self.qualified_name
async def _parse_arguments(self, ctx: Context) -> None:
async def _parse_arguments(self, ctx: Context[BotT]) -> None:
ctx.args = [ctx] if self.cog is None else [self.cog, ctx]
ctx.kwargs = {}
args = ctx.args
@ -752,7 +764,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not self.ignore_extra and not view.eof:
raise TooManyArguments('Too many arguments passed to ' + self.qualified_name)
async def call_before_hooks(self, ctx: Context) -> None:
async def call_before_hooks(self, ctx: Context[BotT], /) -> None:
# now that we're done preparing we can call the pre-command hooks
# first, call the command local hook:
cog = self.cog
@ -777,7 +789,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None:
await hook(ctx)
async def call_after_hooks(self, ctx: Context) -> None:
async def call_after_hooks(self, ctx: Context[BotT], /) -> None:
cog = self.cog
if self._after_invoke is not None:
instance = getattr(self._after_invoke, '__self__', cog)
@ -796,7 +808,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if hook is not None:
await hook(ctx)
def _prepare_cooldowns(self, ctx: Context) -> None:
def _prepare_cooldowns(self, ctx: Context[BotT]) -> None:
if self._buckets.valid:
dt = ctx.message.edited_at or ctx.message.created_at
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
@ -806,7 +818,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if retry_after:
raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore
async def prepare(self, ctx: Context) -> None:
async def prepare(self, ctx: Context[BotT], /) -> None:
ctx.command = self
if not await self.can_run(ctx):
@ -830,9 +842,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
await self._max_concurrency.release(ctx) # type: ignore
raise
def is_on_cooldown(self, ctx: Context) -> bool:
def is_on_cooldown(self, ctx: Context[BotT], /) -> bool:
"""Checks whether the command is currently on cooldown.
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters
-----------
ctx: :class:`.Context`
@ -851,9 +867,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
current = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
return bucket.get_tokens(current) == 0
def reset_cooldown(self, ctx: Context) -> None:
def reset_cooldown(self, ctx: Context[BotT], /) -> None:
"""Resets the cooldown on this command.
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters
-----------
ctx: :class:`.Context`
@ -863,11 +883,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
bucket = self._buckets.get_bucket(ctx.message)
bucket.reset()
def get_cooldown_retry_after(self, ctx: Context) -> float:
def get_cooldown_retry_after(self, ctx: Context[BotT], /) -> float:
"""Retrieves the amount of seconds before this command can be tried again.
.. versionadded:: 1.4
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters
-----------
ctx: :class:`.Context`
@ -887,7 +911,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return 0.0
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT], /) -> None:
await self.prepare(ctx)
# terminate the invoked_subcommand chain.
@ -896,9 +920,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
await injected(*ctx.args, **ctx.kwargs) # type: ignore
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
async def reinvoke(self, ctx: Context[BotT], /, *, call_hooks: bool = False) -> None:
ctx.command = self
await self._parse_arguments(ctx)
@ -915,13 +939,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if call_hooks:
await self.call_after_hooks(ctx)
def error(self, coro: FuncT) -> FuncT:
def error(self, coro: ErrorT, /) -> ErrorT:
"""A decorator that registers a coroutine as a local error handler.
A local error handler is an :func:`.on_command_error` event limited to
a single command. However, the :func:`.on_command_error` is still
invoked afterwards as the catch-all.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters
-----------
coro: :ref:`coroutine <coroutine>`
@ -936,7 +964,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The error handler must be a coroutine.')
self.on_error: Error = coro
self.on_error: Error[Any] = coro
return coro
def has_error_handler(self) -> bool:
@ -946,7 +974,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
"""
return hasattr(self, 'on_error')
def before_invoke(self, coro: FuncT) -> FuncT:
def before_invoke(self, coro: HookT, /) -> HookT:
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@ -957,6 +985,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
See :meth:`.Bot.before_invoke` for more info.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters
-----------
coro: :ref:`coroutine <coroutine>`
@ -973,7 +1005,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
self._before_invoke = coro
return coro
def after_invoke(self, coro: FuncT) -> FuncT:
def after_invoke(self, coro: HookT, /) -> HookT:
"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@ -984,6 +1016,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
See :meth:`.Bot.after_invoke` for more info.
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Parameters
-----------
coro: :ref:`coroutine <coroutine>`
@ -1075,7 +1111,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
return ' '.join(result)
async def can_run(self, ctx: Context) -> bool:
async def can_run(self, ctx: Context[BotT], /) -> bool:
"""|coro|
Checks if the command can be executed by checking all the predicates
@ -1085,6 +1121,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
.. versionchanged:: 1.3
Checks whether the command is disabled or not
.. versionchanged:: 2.0
``ctx`` parameter is now positional-only.
Parameters
-----------
ctx: :class:`.Context`
@ -1341,11 +1381,11 @@ class GroupMixin(Generic[CogT]):
def command(
self,
name: str = MISSING,
cls: Type[Command] = MISSING,
cls: Type[Command[Any, ..., Any]] = MISSING,
*args: Any,
**kwargs: Any,
) -> Any:
"""A shortcut decorator that invokes :func:`.command` and adds it to
"""A shortcut decorator that invokes :func:`~discord.ext.commands.command` and adds it to
the internal command list via :meth:`~.GroupMixin.add_command`.
Returns
@ -1401,7 +1441,7 @@ class GroupMixin(Generic[CogT]):
def group(
self,
name: str = MISSING,
cls: Type[Group] = MISSING,
cls: Type[Group[Any, ..., Any]] = MISSING,
*args: Any,
**kwargs: Any,
) -> Any:
@ -1461,9 +1501,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
ret = super().copy()
for cmd in self.commands:
ret.add_command(cmd.copy())
return ret # type: ignore
return ret
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx: Context[BotT], /) -> None:
ctx.invoked_subcommand = None
ctx.subcommand_passed = None
early_invoke = not self.invoke_without_command
@ -1481,7 +1521,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback)
await injected(*ctx.args, **ctx.kwargs)
await injected(*ctx.args, **ctx.kwargs) # type: ignore
ctx.invoked_parents.append(ctx.invoked_with) # type: ignore
@ -1494,7 +1534,7 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
view.previous = previous
await super().invoke(ctx)
async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:
async def reinvoke(self, ctx: Context[BotT], /, *, call_hooks: bool = False) -> None:
ctx.invoked_subcommand = None
early_invoke = not self.invoke_without_command
if early_invoke:
@ -1592,7 +1632,7 @@ def command(
def command(
name: str = MISSING,
cls: Type[Command] = MISSING,
cls: Type[Command[Any, ..., Any]] = MISSING,
**attrs: Any,
) -> Any:
"""A decorator that transforms a function into a :class:`.Command`
@ -1662,12 +1702,12 @@ def group(
def group(
name: str = MISSING,
cls: Type[Group] = MISSING,
cls: Type[Group[Any, ..., Any]] = MISSING,
**attrs: Any,
) -> Any:
"""A decorator that transforms a function into a :class:`.Group`.
This is similar to the :func:`.command` decorator but the ``cls``
This is similar to the :func:`~discord.ext.commands.command` decorator but the ``cls``
parameter is set to :class:`Group` by default.
.. versionchanged:: 1.1
@ -1679,7 +1719,7 @@ def group(
return command(name=name, cls=cls, **attrs)
def check(predicate: Check) -> Callable[[T], T]:
def check(predicate: Check[ContextT], /) -> Callable[[T], T]:
r"""A decorator that adds a check to the :class:`.Command` or its
subclasses. These checks could be accessed via :attr:`.Command.checks`.
@ -1744,6 +1784,10 @@ def check(predicate: Check) -> Callable[[T], T]:
async def only_me(ctx):
await ctx.send('Only you!')
.. versionchanged:: 2.0
``predicate`` parameter is now positional-only.
Parameters
-----------
predicate: Callable[[:class:`Context`], :class:`bool`]
@ -1774,7 +1818,7 @@ def check(predicate: Check) -> Callable[[T], T]:
return decorator # type: ignore
def check_any(*checks: Check) -> Callable[[T], T]:
def check_any(*checks: Check[ContextT]) -> Callable[[T], T]:
r"""A :func:`check` that is added that checks if any of the checks passed
will pass, i.e. using logical OR.
@ -1827,7 +1871,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
else:
unwrapped.append(pred)
async def predicate(ctx: Context) -> bool:
async def predicate(ctx: Context[BotT]) -> bool:
errors = []
for func in unwrapped:
try:
@ -1843,7 +1887,7 @@ def check_any(*checks: Check) -> Callable[[T], T]:
return check(predicate)
def has_role(item: Union[int, str]) -> Callable[[T], T]:
def has_role(item: Union[int, str], /) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
@ -1864,13 +1908,17 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]:
Raise :exc:`.MissingRole` or :exc:`.NoPrivateMessage`
instead of generic :exc:`.CheckFailure`
.. versionchanged:: 2.0
``item`` parameter is now positional-only.
Parameters
-----------
item: Union[:class:`int`, :class:`str`]
The name or ID of the role to check.
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
@ -1923,7 +1971,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
raise NoPrivateMessage()
# ctx.guild is None doesn't narrow ctx.author to Member
getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore
getter = functools.partial(discord.utils.get, ctx.author.roles)
if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items):
return True
raise MissingAnyRole(list(items))
@ -1931,7 +1979,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]:
return check(predicate)
def bot_has_role(item: int) -> Callable[[T], T]:
def bot_has_role(item: int, /) -> Callable[[T], T]:
"""Similar to :func:`.has_role` except checks if the bot itself has the
role.
@ -1943,6 +1991,10 @@ def bot_has_role(item: int) -> Callable[[T], T]:
Raise :exc:`.BotMissingRole` or :exc:`.NoPrivateMessage`
instead of generic :exc:`.CheckFailure`
.. versionchanged:: 2.0
``item`` parameter is now positional-only.
"""
def predicate(ctx):
@ -2022,7 +2074,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
ch = ctx.channel
permissions = ch.permissions_for(ctx.author) # type: ignore
@ -2048,7 +2100,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
guild = ctx.guild
me = guild.me if guild is not None else ctx.bot.user
permissions = ctx.channel.permissions_for(me) # type: ignore
@ -2077,7 +2129,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild:
raise NoPrivateMessage
@ -2103,7 +2155,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]:
if invalid:
raise TypeError(f"Invalid permission(s): {', '.join(invalid)}")
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if not ctx.guild:
raise NoPrivateMessage
@ -2129,7 +2181,7 @@ def dm_only() -> Callable[[T], T]:
.. versionadded:: 1.1
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is not None:
raise PrivateMessageOnly()
return True
@ -2146,7 +2198,7 @@ def guild_only() -> Callable[[T], T]:
that is inherited from :exc:`.CheckFailure`.
"""
def predicate(ctx: Context) -> bool:
def predicate(ctx: Context[BotT]) -> bool:
if ctx.guild is None:
raise NoPrivateMessage()
return True
@ -2164,7 +2216,7 @@ def is_owner() -> Callable[[T], T]:
from :exc:`.CheckFailure`.
"""
async def predicate(ctx: Context) -> bool:
async def predicate(ctx: Context[BotT]) -> bool:
if not await ctx.bot.is_owner(ctx.author):
raise NotOwner('You do not own this bot.')
return True
@ -2184,7 +2236,7 @@ def is_nsfw() -> Callable[[T], T]:
DM channels will also now pass this check.
"""
def pred(ctx: Context) -> bool:
def pred(ctx: Context[BotT]) -> bool:
ch = ctx.channel
if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()):
return True
@ -2314,7 +2366,7 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait:
return decorator # type: ignore
def before_invoke(coro) -> Callable[[T], T]:
def before_invoke(coro: Hook[ContextT], /) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a pre-invoke hook.
This allows you to refer to one before invoke hook for several commands that
@ -2322,6 +2374,10 @@ def before_invoke(coro) -> Callable[[T], T]:
.. versionadded:: 1.4
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
Example
---------
@ -2350,7 +2406,6 @@ def before_invoke(coro) -> Callable[[T], T]:
async def why(self, ctx): # Output: <Nothing>
await ctx.send('because someone made me')
bot.add_cog(What())
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:
@ -2363,13 +2418,17 @@ def before_invoke(coro) -> Callable[[T], T]:
return decorator # type: ignore
def after_invoke(coro) -> Callable[[T], T]:
def after_invoke(coro: Hook[ContextT], /) -> Callable[[T], T]:
"""A decorator that registers a coroutine as a post-invoke hook.
This allows you to refer to one after invoke hook for several commands that
do not have to be within the same cog.
.. versionadded:: 1.4
.. versionchanged:: 2.0
``coro`` parameter is now positional-only.
"""
def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]:

33
discord/ext/commands/errors.py

@ -39,6 +39,8 @@ if TYPE_CHECKING:
from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList
from ._types import BotT
__all__ = (
'CommandError',
@ -70,6 +72,7 @@ __all__ = (
'BadInviteArgument',
'EmojiNotFound',
'GuildStickerNotFound',
'ScheduledEventNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
@ -134,8 +137,8 @@ class ConversionError(CommandError):
the ``__cause__`` attribute.
"""
def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter
def __init__(self, converter: Converter[Any], original: Exception) -> None:
self.converter: Converter[Any] = converter
self.original: Exception = original
@ -223,9 +226,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed.
"""
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context[BotT]], bool]]) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
self.errors: List[Callable[[Context[BotT]], bool]] = errors
super().__init__('You do not have permission to run this command.')
@ -515,6 +518,24 @@ class GuildStickerNotFound(BadArgument):
super().__init__(f'Sticker "{argument}" not found.')
class ScheduledEventNotFound(BadArgument):
"""Exception raised when the bot can not find the scheduled event.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.0
Attributes
-----------
argument: :class:`str`
The event supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'ScheduledEvent "{argument}" not found.')
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
@ -788,9 +809,9 @@ class BadUnionArgument(UserInputError):
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
def __init__(self, param: Parameter, converters: Tuple[type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters
self.converters: Tuple[type, ...] = converters
self.errors: List[CommandError] = errors
def _get_name(x):

24
discord/ext/commands/flags.py

@ -49,8 +49,6 @@ from typing import (
Tuple,
List,
Any,
Type,
TypeVar,
Union,
)
@ -70,6 +68,8 @@ if TYPE_CHECKING:
from .context import Context
from ._types import BotT
@dataclass
class Flag:
@ -148,7 +148,7 @@ def flag(
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]):
def validate_flag_name(name: str, forbidden: Set[str]) -> None:
if not name:
raise ValueError('flag names should not be empty')
@ -348,7 +348,7 @@ class FlagsMeta(type):
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
async def tuple_convert_all(ctx: Context[BotT], argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -373,7 +373,7 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter:
return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
async def tuple_convert_flag(ctx: Context[BotT], argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
@ -401,7 +401,7 @@ async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters
return tuple(results)
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
async def convert_flag(ctx: Context[BotT], argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation
try:
@ -480,12 +480,13 @@ class FlagConverter(metaclass=FlagsMeta):
yield (flag.name, getattr(self, flag.attribute))
@classmethod
async def _construct_default(cls, ctx: Context) -> Self:
async def _construct_default(cls, ctx: Context[BotT]) -> Self:
self = cls.__new__(cls)
flags = cls.__commands_flags__
for flag in flags.values():
if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx)
# Type checker does not understand that flag.default is a Callable
default = await maybe_coroutine(flag.default, ctx) # type: ignore
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
@ -546,7 +547,7 @@ class FlagConverter(metaclass=FlagsMeta):
return result
@classmethod
async def convert(cls, ctx: Context, argument: str) -> Self:
async def convert(cls, ctx: Context[BotT], argument: str) -> Self:
"""|coro|
The method that actually converters an argument to the flag mapping.
@ -584,7 +585,8 @@ class FlagConverter(metaclass=FlagsMeta):
raise MissingRequiredFlag(flag)
else:
if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx)
# Type checker does not understand flag.default is a Callable
default = await maybe_coroutine(flag.default, ctx) # type: ignore
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
@ -610,7 +612,7 @@ class FlagConverter(metaclass=FlagsMeta):
values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict:
values = dict(values) # type: ignore
values = dict(values)
setattr(self, flag.attribute, values)

585
discord/ext/commands/help.py

File diff suppressed because it is too large

37
discord/ext/commands/view.py

@ -21,6 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
@ -47,24 +52,24 @@ _all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
self.buffer = buffer
self.end = len(buffer)
def __init__(self, buffer: str) -> None:
self.index: int = 0
self.buffer: str = buffer
self.end: int = len(buffer)
self.previous = 0
@property
def current(self):
def current(self) -> Optional[str]:
return None if self.eof else self.buffer[self.index]
@property
def eof(self):
def eof(self) -> bool:
return self.index >= self.end
def undo(self):
def undo(self) -> None:
self.index = self.previous
def skip_ws(self):
def skip_ws(self) -> bool:
pos = 0
while not self.eof:
try:
@ -79,7 +84,7 @@ class StringView:
self.index += pos
return self.previous != self.index
def skip_string(self, string):
def skip_string(self, string: str) -> bool:
strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index
@ -87,19 +92,19 @@ class StringView:
return True
return False
def read_rest(self):
def read_rest(self) -> str:
result = self.buffer[self.index :]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
def read(self, n: int) -> str:
result = self.buffer[self.index : self.index + n]
self.previous = self.index
self.index += n
return result
def get(self):
def get(self) -> Optional[str]:
try:
result = self.buffer[self.index + 1]
except IndexError:
@ -109,7 +114,7 @@ class StringView:
self.index += 1
return result
def get_word(self):
def get_word(self) -> str:
pos = 0
while not self.eof:
try:
@ -119,12 +124,12 @@ class StringView:
pos += 1
except IndexError:
break
self.previous = self.index
self.previous: int = self.index
result = self.buffer[self.index : self.index + pos]
self.index += pos
return result
def get_quoted_word(self):
def get_quoted_word(self) -> Optional[str]:
current = self.current
if current is None:
return None
@ -187,5 +192,5 @@ class StringView:
result.append(current)
def __repr__(self):
def __repr__(self) -> str:
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

186
discord/ext/tasks/__init__.py

@ -26,6 +26,7 @@ from __future__ import annotations
import asyncio
import datetime
import logging
from typing import (
Any,
Awaitable,
@ -48,6 +49,8 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
_log = logging.getLogger(__name__)
# fmt: off
__all__ = (
'loop',
@ -61,19 +64,61 @@ FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
def is_ambiguous(dt: datetime.datetime) -> bool:
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
# Naive or fixed timezones are never ambiguous
return False
before = dt.replace(fold=0)
after = dt.replace(fold=1)
same_offset = before.utcoffset() == after.utcoffset()
same_dst = before.dst() == after.dst()
return not (same_offset and same_dst)
def is_imaginary(dt: datetime.datetime) -> bool:
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
# Naive or fixed timezones are never imaginary
return False
tz = dt.tzinfo
dt = dt.replace(tzinfo=None)
roundtrip = dt.replace(tzinfo=tz).astimezone(datetime.timezone.utc).astimezone(tz).replace(tzinfo=None)
return dt != roundtrip
def resolve_datetime(dt: datetime.datetime) -> datetime.datetime:
if dt.tzinfo is None or isinstance(dt.tzinfo, datetime.timezone):
# Naive or fixed requires no resolution
return dt
if is_imaginary(dt):
# Largest gap is probably 24 hours
tomorrow = dt + datetime.timedelta(days=1)
yesterday = dt - datetime.timedelta(days=1)
# utcoffset shouldn't return None since these are aware instances
# If it returns None then the timezone implementation was broken from the get go
return dt + (tomorrow.utcoffset() - yesterday.utcoffset()) # type: ignore
elif is_ambiguous(dt):
return dt.replace(fold=1)
else:
return dt
class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
self.future = future = loop.create_future()
self.loop: asyncio.AbstractEventLoop = loop
self.future: asyncio.Future[None] = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True)
self.handle = loop.call_later(relative_delta, self.future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
self.handle: asyncio.TimerHandle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]:
return self.future
@ -101,15 +146,13 @@ class Loop(Generic[LF]):
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
loop: asyncio.AbstractEventLoop,
) -> None:
self.coro: LF = coro
self.reconnect: bool = reconnect
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count
self._current_loop = 0
self._handle: SleepHandle = MISSING
self._task: asyncio.Task[None] = MISSING
self._handle: Optional[SleepHandle] = None
self._task: Optional[asyncio.Task[None]] = None
self._injected = None
self._valid_exception = (
OSError,
@ -147,16 +190,20 @@ class Loop(Generic[LF]):
await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop)
self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
return self._handle.wait()
def _is_relative_time(self) -> bool:
return self._time is MISSING
def _is_explicit_time(self) -> bool:
return self._time is not MISSING
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
self._prepare_time_index()
if self._is_explicit_time():
self._next_iteration = self._get_next_sleep_time()
else:
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
@ -166,11 +213,30 @@ class Loop(Generic[LF]):
return
while True:
# sleep before the body of the task for explicit time intervals
if self._time is not MISSING:
if self._is_explicit_time():
await self._try_sleep_until(self._next_iteration)
if not self._last_iteration_failed:
self._last_iteration = self._next_iteration
self._next_iteration = self._get_next_sleep_time()
# In order to account for clock drift, we need to ensure that
# the next iteration always follows the last iteration.
# Sometimes asyncio is cheeky and wakes up a few microseconds before our target
# time, causing it to repeat a run.
while self._is_explicit_time() and self._next_iteration <= self._last_iteration:
_log.warn(
(
'Clock drift detected for task %s. Woke up at %s but needed to sleep until %s. '
'Sleeping until %s again to correct clock'
),
self.coro.__qualname__,
discord.utils.utcnow(),
self._next_iteration,
self._next_iteration,
)
await self._try_sleep_until(self._next_iteration)
self._next_iteration = self._get_next_sleep_time()
try:
await self.coro(*args, **kwargs)
self._last_iteration_failed = False
@ -184,7 +250,7 @@ class Loop(Generic[LF]):
return
# sleep after the body of the task for relative time intervals
if self._time is MISSING:
if self._is_relative_time():
await self._try_sleep_until(self._next_iteration)
self._current_loop += 1
@ -200,7 +266,8 @@ class Loop(Generic[LF]):
raise exc
finally:
await self._call_loop_function('after_loop')
self._handle.cancel()
if self._handle:
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
self._stop_next_iteration = False
@ -218,7 +285,6 @@ class Loop(Generic[LF]):
time=self._time,
count=self.count,
reconnect=self.reconnect,
loop=self.loop,
)
copy._injected = obj
copy._before_loop = self._before_loop
@ -325,16 +391,13 @@ class Loop(Generic[LF]):
The task that has been created.
"""
if self._task is not MISSING and not self._task.done():
if self._task and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None:
args = (self._injected, *args)
if self.loop is MISSING:
self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
self._task = asyncio.create_task(self._loop(*args, **kwargs))
return self._task
def stop(self) -> None:
@ -358,7 +421,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2
"""
if self._task is not MISSING and not self._task.done():
if self._task and not self._task.done():
self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool:
@ -366,7 +429,7 @@ class Loop(Generic[LF]):
def cancel(self) -> None:
"""Cancels the internal task, if it is running."""
if self._can_be_cancelled():
if self._can_be_cancelled() and self._task:
self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None:
@ -386,10 +449,11 @@ class Loop(Generic[LF]):
"""
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
self._task.remove_done_callback(restart_when_over)
if self._task:
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
if self._can_be_cancelled():
if self._can_be_cancelled() and self._task:
self._task.add_done_callback(restart_when_over)
self._task.cancel()
@ -468,7 +532,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.4
"""
return not bool(self._task.done()) if self._task is not MISSING else False
return not bool(self._task.done()) if self._task else False
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
@ -557,47 +621,50 @@ class Loop(Generic[LF]):
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime:
def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime:
if self._sleep is not MISSING:
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
if self._time_index >= len(self._time):
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index]
if now is MISSING:
now = datetime.datetime.now(datetime.timezone.utc)
if self._current_loop == 0:
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
index = self._start_time_relative_to(now)
next_date = self._last_iteration
if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1)
if index is None:
time = self._time[0]
tomorrow = now.astimezone(time.tzinfo) + datetime.timedelta(days=1)
date = tomorrow.date()
else:
time = self._time[index]
date = now.astimezone(time.tzinfo).date()
self._time_index += 1
return datetime.datetime.combine(next_date, next_time)
dt = datetime.datetime.combine(date, time, tzinfo=time.tzinfo)
return resolve_datetime(dt)
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
def _start_time_relative_to(self, now: datetime.datetime) -> Optional[int]:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from
# pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
idx = -1
# Sole time comparisons are apparently broken, therefore, attach today's date
# to it in order to make the comparisons make sense.
# For example, if given a list of times [0, 3, 18]
# If it's 04:00 today then we know we have to wait until 18:00 today
# If it's 19:00 today then we know we we have to wait until 00:00 tomorrow
# Note that timezones need to be taken into consideration for this to work.
# If the timezone is set to UTC+9 and the now timezone is UTC
# A conversion needs to be done.
# i.e. 03:00 UTC+9 -> 18:00 UTC the previous day
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
break
# Convert the current time to the target timezone
# e.g. 18:00 UTC -> 03:00 UTC+9
# Then compare the time instances to see if they're the same
start = now.astimezone(time.tzinfo)
if time >= start.timetz():
return idx
else:
self._time_index = idx + 1
return None
def _get_time_parameter(
self,
@ -687,12 +754,8 @@ class Loop(Generic[LF]):
self._sleep = self._seconds = self._minutes = self._hours = MISSING
if self.is_running():
if self._time is not MISSING:
# prepare the next time index starting from after the last iteration
self._prepare_time_index(now=self._last_iteration)
self._next_iteration = self._get_next_sleep_time()
if self._handle is not MISSING and not self._handle.done():
if self._handle and not self._handle.done():
# the loop is sleeping, recalculate based on new interval
self._handle.recalculate(self._next_iteration)
@ -705,7 +768,6 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@ -738,9 +800,6 @@ def loop(
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
Raises
--------
@ -760,7 +819,6 @@ def loop(
count=count,
time=time,
reconnect=reconnect,
loop=loop,
)
return decorator

2
discord/file.py

@ -78,7 +78,7 @@ class File:
def __init__(
self,
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase],
fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase],
filename: Optional[str] = None,
*,
spoiler: bool = False,

21
discord/flags.py

@ -45,8 +45,8 @@ BF = TypeVar('BF', bound='BaseFlags')
class flag_value:
def __init__(self, func: Callable[[Any], int]):
self.flag = func(None)
self.__doc__ = func.__doc__
self.flag: int = func(None)
self.__doc__: Optional[str] = func.__doc__
@overload
def __get__(self, instance: None, owner: Type[BF]) -> Self:
@ -64,7 +64,7 @@ class flag_value:
def __set__(self, instance: BaseFlags, value: bool) -> None:
instance._set_flag(self.flag, value)
def __repr__(self):
def __repr__(self) -> str:
return f'<flag_value flag={self.flag!r}>'
@ -72,8 +72,8 @@ class alias_flag_value(flag_value):
pass
def fill_with_flags(*, inverted: bool = False):
def decorator(cls: Type[BF]):
def fill_with_flags(*, inverted: bool = False) -> Callable[[Type[BF]], Type[BF]]:
def decorator(cls: Type[BF]) -> Type[BF]:
# fmt: off
cls.VALID_FLAGS = {
name: value.flag
@ -115,10 +115,10 @@ class BaseFlags:
self.value = value
return self
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.value == other.value
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -533,6 +533,11 @@ class PrivateUserFlags(PublicUserFlags):
""":class:`bool`: Returns ``True`` if the user has a partner or a verification application."""
return UserFlags.partner_or_verification_application.value
@flag_value
def disable_premium(self):
""":class:`bool`: Returns ``True`` if the user bought premium but has it manually disabled."""
return UserFlags.disable_premium.value
@fill_with_flags()
class MemberCacheFlags(BaseFlags):
@ -576,7 +581,7 @@ class MemberCacheFlags(BaseFlags):
def __init__(self, **kwargs: bool):
bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1
self.value: int = (1 << bits) - 1
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.')

84
discord/gateway.py

@ -25,7 +25,6 @@ from __future__ import annotations
import asyncio
from collections import deque
import concurrent.futures
import logging
import struct
import time
@ -54,8 +53,11 @@ __all__ = (
)
if TYPE_CHECKING:
from typing_extensions import Self
from .client import Client
from .state import ConnectionState
from .types.snowflake import Snowflake
from .voice_client import VoiceClient
@ -108,9 +110,6 @@ class GatewayRatelimiter:
return self.per - (current - self.window)
self.remaining -= 1
if self.remaining == 0:
self.window = current
return 0.0
async def block(self) -> None:
@ -222,7 +221,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
ack_time = time.perf_counter()
self._last_ack = ack_time
self._last_recv = ack_time
self.latency = ack_time - self._last_send
self.latency: float = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
if self.latency > 10:
_log.warning(self.behind_msg, self.latency)
@ -345,7 +344,7 @@ class DiscordWebSocket:
@classmethod
async def from_client(
cls: Type[DWS],
cls,
client: Client,
*,
initial: bool = False,
@ -353,7 +352,7 @@ class DiscordWebSocket:
session: Optional[str] = None,
sequence: Optional[int] = None,
resume: bool = False,
) -> DWS:
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
@ -662,21 +661,21 @@ class DiscordWebSocket:
if activities is not None:
if not all(isinstance(activity, BaseActivity) for activity in activities):
raise TypeError('activity must derive from BaseActivity')
activities = [activity.to_dict() for activity in activities]
activities_data = [activity.to_dict() for activity in activities]
else:
activities = []
activities_data = []
if status == 'idle':
since = int(time.time() * 1000)
payload = {'op': self.PRESENCE, 'd': {'activities': activities, 'afk': afk, 'since': since, 'status': str(status)}}
payload = {'op': self.PRESENCE, 'd': {'activities': activities_data, 'afk': afk, 'since': since, 'status': str(status)}}
sent = utils._to_json(payload)
_log.debug('Sending "%s" to change presence.', sent)
await self.send(sent)
async def request_lazy_guild(
self, guild_id, *, typing=None, threads=None, activities=None, members=None, channels=None, thread_member_lists=None
self, guild_id: Snowflake, *, typing: Optional[bool] = None, threads: Optional[bool] = None, activities: Optional[bool] = None, members: Optional[List[Snowflake]]=None, channels: Optional[Dict[Snowflake, List[List[int]]]]=None, thread_member_lists: Optional[List[Snowflake]]=None
):
payload = {
'op': self.GUILD_SUBSCRIBE,
@ -704,11 +703,11 @@ class DiscordWebSocket:
async def request_chunks(
self,
guild_ids: List[int],
guild_ids: List[Snowflake],
query: Optional[str] = None,
*,
limit: Optional[int] = None,
user_ids: Optional[List[int]] = None,
user_ids: Optional[List[Snowflake]] = None,
presences: bool = True,
nonce: Optional[str] = None,
) -> None:
@ -723,7 +722,7 @@ class DiscordWebSocket:
},
}
if nonce:
if nonce is not None:
payload['d']['nonce'] = nonce
await self.send_as_json(payload)
@ -755,7 +754,7 @@ class DiscordWebSocket:
_log.debug('Updating %s voice state to %s.', guild_id or 'client', payload)
await self.send_as_json(payload)
async def access_dm(self, channel_id: int):
async def access_dm(self, channel_id: Snowflake):
payload = {'op': self.CALL_CONNECT, 'd': {'channel_id': str(channel_id)}}
_log.debug('Sending ACCESS_DM for channel %s.', channel_id)
@ -763,7 +762,7 @@ class DiscordWebSocket:
async def request_commands(
self,
guild_id: int,
guild_id: Snowflake,
type: int,
*,
nonce: Optional[str] = None,
@ -771,13 +770,13 @@ class DiscordWebSocket:
applications: Optional[bool] = None,
offset: int = 0,
query: Optional[str] = None,
command_ids: Optional[List[int]] = None,
application_id: Optional[int] = None,
command_ids: Optional[List[Snowflake]] = None,
application_id: Optional[Snowflake] = None,
) -> None:
payload = {
'op': self.REQUEST_COMMANDS,
'd': {
'guild_id': guild_id,
'guild_id': str(guild_id),
'type': type,
},
}
@ -795,7 +794,7 @@ class DiscordWebSocket:
if command_ids is not None:
payload['d']['command_ids'] = command_ids
if application_id is not None:
payload['d']['application_id'] = application_id
payload['d']['application_id'] = str(application_id)
await self.send_as_json(payload)
@ -871,11 +870,11 @@ class DiscordVoiceWebSocket:
*,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> None:
self.ws = socket
self.loop = loop
self._keep_alive = None
self._close_code = None
self.secret_key = None
self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive: Optional[VoiceKeepAliveHandler] = None
self._close_code: Optional[int] = None
self.secret_key: Optional[str] = None
if hook:
self._hook = hook # type: ignore - type-checker doesn't like overriding methods
@ -914,7 +913,9 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
@classmethod
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS:
async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http
@ -971,29 +972,30 @@ class DiscordVoiceWebSocket:
async def received_message(self, msg: Dict[str, Any]) -> None:
_log.debug('Voice gateway event: %s.', msg)
op = msg['op']
data = msg.get('d')
data = msg['d'] # According to Discord this key is always given
if op == self.READY:
await self.initial_connection(data) # type: ignore - type-checker thinks data could be None
await self.initial_connection(data)
elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack() # type: ignore - _keep_alive can't be None at this point
if self._keep_alive:
self._keep_alive.ack()
elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.')
self.secret_key = self._connection.secret_key
elif op == self.SELECT_PROTOCOL_ACK:
self._connection.mode = data['mode'] # type: ignore - data can't be None at this point
await self.load_secret_key(data) # type: ignore - data can't be None at this point
self._connection.mode = data['mode']
await self.load_secret_key(data)
elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0 # type: ignore - type-checker thinks data could be None
interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start()
elif op == self.SPEAKING:
state = self._connection
user_id = int(data['user_id']) # type: ignore - data can't be None at this point
speaking = data['speaking'] # type: ignore - data can't be None at this point
user_id = int(data['user_id'])
speaking = data['speaking']
ssrc = state._flip_ssrc(user_id)
if ssrc is None:
state._set_ssrc(user_id, SSRC(data['ssrc'], speaking)) # type: ignore - data can't be None at this point
state._set_ssrc(user_id, SSRC(data['ssrc'], speaking))
else:
ssrc.speaking = speaking
@ -1019,17 +1021,17 @@ class DiscordVoiceWebSocket:
# The IP is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.endpoint_ip = recv[ip_start:ip_end].decode('ascii')
state.ip = recv[ip_start:ip_end].decode('ascii')
state.voice_port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.endpoint_ip, state.voice_port)
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
# There *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('Received supported encryption modes: %s.', ", ".join(modes))
mode = modes[0]
await self.select_protocol(state.endpoint_ip, state.voice_port, mode)
await self.select_protocol(state.ip, state.port, mode)
_log.info('Selected the voice protocol for use: %s.', mode)
@property
@ -1047,9 +1049,9 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
async def load_secret_key(self, data: Dict[str, Any]) -> None:
_log.info('Received secret key for voice connection.')
self.secret_key = self._connection.secret_key = data.get('secret_key')
self.secret_key = self._connection.secret_key = data['secret_key']
await self.speak()
await self.speak(SpeakingState.none)

141
discord/guild.py

@ -32,9 +32,11 @@ from typing import (
Any,
AsyncIterator,
ClassVar,
Collection,
Coroutine,
Dict,
List,
Mapping,
NamedTuple,
Sequence,
Set,
@ -129,6 +131,7 @@ if TYPE_CHECKING:
)
from .types.integration import IntegrationType
from .types.snowflake import SnowflakeList
from .types.widget import EditWidgetSettings
VocalGuildChannel = Union[VoiceChannel, StageChannel]
GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel]
@ -340,6 +343,7 @@ class Guild(Hashable):
self._state: ConnectionState = state
self.notification_settings: Optional[GuildSettings] = None
self.command_counts: Optional[CommandCounts] = None
self._member_count: int = 0
self._from_data(data)
def _add_channel(self, channel: GuildChannel, /) -> None:
@ -384,7 +388,7 @@ class Guild(Hashable):
('id', self.id),
('name', self.name),
('chunked', self.chunked),
('member_count', getattr(self, '_member_count', None)),
('member_count', self._member_count),
)
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Guild {inner}>'
@ -435,9 +439,10 @@ class Guild(Hashable):
return role
def _from_data(self, guild: GuildPayload) -> None:
member_count = guild.get('member_count', guild.get('approximate_member_count'))
if member_count is not None:
self._member_count: int = member_count
try:
self._member_count: int = guild['member_count']
except KeyError:
pass
self.id: int = int(guild['id'])
self.name: str = guild.get('name', '')
@ -506,7 +511,7 @@ class Guild(Hashable):
self.owner_application_id: Optional[int] = utils._get_as_snowflake(guild, 'application_id')
self.premium_progress_bar_enabled: bool = guild.get('premium_progress_bar_enabled', False)
large = None if member_count is None else member_count >= 250
large = None if self._member_count is 0 else self._member_count >= 250
self._large: Optional[bool] = guild.get('large', large)
if (settings := guild.get('settings')) is not None:
@ -552,10 +557,9 @@ class Guild(Hashable):
members, which for this library is set to the maximum of 250.
"""
if self._large is None:
try:
if self._member_count is not None:
return self._member_count >= 250
except AttributeError:
return len(self._members) >= 250
return len(self._members) >= 250
return self._large
@property
@ -958,13 +962,16 @@ class Guild(Hashable):
return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes')
@property
def member_count(self) -> int:
""":class:`int`: Returns the true member count regardless of it being loaded fully or not.
def member_count(self) -> Optional[int]:
"""Optional[:class:`int`]: Returns the member count if available.
.. warning::
Due to a Discord limitation, this may not always be up-to-date and accurate.
.. versionchanged:: 2.0
Now returns an ``Optional[int]``.
"""
return self._member_count
@ -1054,7 +1061,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, TextChannelPayload]:
@ -1065,7 +1072,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, VoiceChannelPayload]:
@ -1076,7 +1083,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.stage_voice],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, StageChannelPayload]:
@ -1087,7 +1094,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.category],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, CategoryChannelPayload]:
@ -1098,7 +1105,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.news],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, NewsChannelPayload]:
@ -1109,7 +1116,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.store],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, StoreChannelPayload]:
@ -1120,7 +1127,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: Literal[ChannelType.text],
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
@ -1131,7 +1138,7 @@ class Guild(Hashable):
self,
name: str,
channel_type: ChannelType,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ...,
category: Optional[Snowflake] = ...,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
@ -1141,13 +1148,13 @@ class Guild(Hashable):
self,
name: str,
channel_type: ChannelType,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
category: Optional[Snowflake] = None,
**options: Any,
) -> Coroutine[Any, Any, GuildChannelPayload]:
if overwrites is MISSING:
overwrites = {}
elif not isinstance(overwrites, dict):
elif not isinstance(overwrites, Mapping):
raise TypeError('overwrites parameter expects a dict')
perms = []
@ -1180,7 +1187,7 @@ class Guild(Hashable):
topic: str = MISSING,
slowmode_delay: int = MISSING,
nsfw: bool = MISSING,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
) -> TextChannel:
"""|coro|
@ -1201,8 +1208,8 @@ class Guild(Hashable):
will be required to update the position of the channel in the channel list.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Examples
----------
@ -1297,15 +1304,15 @@ class Guild(Hashable):
user_limit: int = MISSING,
rtc_region: Optional[str] = MISSING,
video_quality_mode: VideoQualityMode = MISSING,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
) -> VoiceChannel:
"""|coro|
This is similar to :meth:`create_text_channel` except makes a :class:`VoiceChannel` instead.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -1383,7 +1390,7 @@ class Guild(Hashable):
*,
topic: str,
position: int = MISSING,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
category: Optional[CategoryChannel] = None,
reason: Optional[str] = None,
) -> StageChannel:
@ -1394,8 +1401,8 @@ class Guild(Hashable):
.. versionadded:: 1.7
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -1451,7 +1458,7 @@ class Guild(Hashable):
self,
name: str,
*,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = MISSING,
overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = MISSING,
reason: Optional[str] = None,
position: int = MISSING,
) -> CategoryChannel:
@ -1465,8 +1472,8 @@ class Guild(Hashable):
cannot have categories.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Raises
------
@ -1575,8 +1582,8 @@ class Guild(Hashable):
The ``region`` keyword parameter has been removed.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
.. versionchanged:: 2.0
The ``preferred_locale`` keyword parameter now accepts an enum instead of :class:`str`.
@ -1635,11 +1642,11 @@ class Guild(Hashable):
The new preferred locale for the guild. Used as the primary language in the guild.
rules_channel: Optional[:class:`TextChannel`]
The new channel that is used for rules. This is only available to
guilds that contain ``PUBLIC`` in :attr:`Guild.features`. Could be ``None`` for no rules
guilds that contain ``COMMUNITY`` in :attr:`Guild.features`. Could be ``None`` for no rules
channel.
public_updates_channel: Optional[:class:`TextChannel`]
The new channel that is used for public updates from Discord. This is only available to
guilds that contain ``PUBLIC`` in :attr:`Guild.features`. Could be ``None`` for no
guilds that contain ``COMMUNITY`` in :attr:`Guild.features`. Could be ``None`` for no
public updates channel.
premium_progress_bar_enabled: :class:`bool`
Whether the premium AKA server boost level progress bar should be enabled for the guild.
@ -1875,7 +1882,7 @@ class Guild(Hashable):
HTTPException
Fetching the profile failed.
InvalidData
The profile is not from this guild.
The member is not in this guild.
Returns
--------
@ -1928,7 +1935,7 @@ class Guild(Hashable):
:class:`BanEntry`
The :class:`BanEntry` object for the specified user.
"""
data: BanPayload = await self._state.http.get_ban(user.id, self.id)
data = await self._state.http.get_ban(user.id, self.id)
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason'])
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
@ -1996,7 +2003,7 @@ class Guild(Hashable):
List[:class:`BanEntry`]
A list of :class:`BanEntry` objects.
"""
data: List[BanPayload] = await self._state.http.get_bans(self.id)
data = await self._state.http.get_bans(self.id)
return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data]
async def prune_members(
@ -2004,7 +2011,7 @@ class Guild(Hashable):
*,
days: int,
compute_prune_count: bool = True,
roles: List[Snowflake] = MISSING,
roles: Collection[Snowflake] = MISSING,
reason: Optional[str] = None,
) -> Optional[int]:
r"""|coro|
@ -2026,8 +2033,8 @@ class Guild(Hashable):
The ``roles`` keyword-only parameter was added.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -2118,7 +2125,7 @@ class Guild(Hashable):
data = await self._state.http.guild_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data]
async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> Optional[int]:
async def estimate_pruned_members(self, *, days: int, roles: Collection[Snowflake] = MISSING) -> Optional[int]:
"""|coro|
Similar to :meth:`prune_members` except instead of actually
@ -2129,8 +2136,8 @@ class Guild(Hashable):
The returned value can be ``None``.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -2673,7 +2680,7 @@ class Guild(Hashable):
*,
name: str,
image: bytes,
roles: List[Role] = MISSING,
roles: Collection[Role] = MISSING,
reason: Optional[str] = None,
) -> Emoji:
r"""|coro|
@ -2831,8 +2838,8 @@ class Guild(Hashable):
The ``display_icon`` keyword-only parameter was added.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -2925,7 +2932,7 @@ class Guild(Hashable):
# TODO: add to cache
return role
async def edit_role_positions(self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]:
async def edit_role_positions(self, positions: Mapping[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]:
"""|coro|
Bulk edits a list of :class:`Role` in the guild.
@ -2936,8 +2943,8 @@ class Guild(Hashable):
.. versionadded:: 1.4
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Example
----------
@ -2974,7 +2981,7 @@ class Guild(Hashable):
List[:class:`Role`]
A list of all the roles in the guild.
"""
if not isinstance(positions, dict):
if not isinstance(positions, Mapping):
raise TypeError('positions parameter expects a dict')
role_positions = []
@ -3024,7 +3031,7 @@ class Guild(Hashable):
user: Snowflake,
*,
reason: Optional[str] = None,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1,
delete_message_days: int = 1,
) -> None:
"""|coro|
@ -3080,6 +3087,16 @@ class Guild(Hashable):
"""
await self._state.http.unban(user.id, self.id, reason=reason)
@property
def vanity_url(self) -> Optional[str]:
"""Optional[:class:`str`]: The Discord vanity invite URL for this guild, if available.
.. versionadded:: 2.0
"""
if self.vanity_url_code is None:
return None
return f'{Invite.BASE}/{self.vanity_url_code}'
async def vanity_invite(self) -> Optional[Invite]:
"""|coro|
@ -3230,7 +3247,7 @@ class Guild(Hashable):
after = Object(id=utils.time_snowflake(after, high=True))
if oldest_first is MISSING:
reverse = after is not None
reverse = after is not MISSING
else:
reverse = oldest_first
@ -3340,7 +3357,7 @@ class Guild(Hashable):
HTTPException
Editing the widget failed.
"""
payload = {}
payload: EditWidgetSettings = {}
if channel is not MISSING:
payload['channel_id'] = None if channel is None else channel.id
if enabled is not MISSING:
@ -3452,14 +3469,14 @@ class Guild(Hashable):
The members that belong to this guild.
"""
if self._offline_members_hidden:
raise ClientException('This guild cannot be chunked.')
raise ClientException('This guild cannot be chunked')
if self._state.is_guild_evicted(self):
raise ClientException('This guild is no longer available.')
raise ClientException('This guild is no longer available')
members = await self._state.chunk_guild(self, channels=[channel] if channel else [])
if members is None:
raise ClientException('Chunking failed.')
return members # type: ignore
raise ClientException('Chunking failed')
return members
async def fetch_members(
self,
@ -3509,14 +3526,14 @@ class Guild(Hashable):
The members that belong to this guild (offline members may not be included).
"""
if self._state.is_guild_evicted(self):
raise ClientException('This guild is no longer available.')
raise ClientException('This guild is no longer available')
members = await self._state.scrape_guild(
self, cache=cache, force_scraping=force_scraping, delay=delay, channels=channels
)
if members is None:
raise ClientException('Fetching members failed')
return members # type: ignore
return members
async def query_members(
self,
@ -3533,7 +3550,7 @@ class Guild(Hashable):
Request members that belong to this guild whose username starts with
the query given.
This is a websocket operation and can be slow.
This is a websocket operation.
.. note::
This is preferrable to using :meth:`fetch_member` as the client uses

75
discord/http.py

@ -57,6 +57,7 @@ from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordSer
from .file import File
from .tracking import ContextProperties
from . import utils
from .mentions import AllowedMentions
from .utils import MISSING
_log = logging.getLogger(__name__)
@ -78,6 +79,7 @@ if TYPE_CHECKING:
appinfo,
audit_log,
channel,
command,
emoji,
guild,
integration,
@ -91,7 +93,6 @@ if TYPE_CHECKING:
widget,
team,
threads,
voice,
scheduled_event,
sticker,
welcome_screen,
@ -121,9 +122,9 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any]
class MultipartParameters(NamedTuple):
payload: Optional[Dict[str, Any]]
multipart: Optional[List[Dict[str, Any]]]
files: Optional[List[File]]
files: Optional[Sequence[File]]
def __enter__(self):
def __enter__(self) -> Self:
return self
def __exit__(
@ -146,10 +147,10 @@ def handle_message_parameters(
nonce: Optional[Union[int, str]] = None,
flags: MessageFlags = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
embeds: Sequence[Embed] = MISSING,
attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = MISSING,
message_reference: Optional[message.MessageReference] = MISSING,
stickers: Optional[SnowflakeList] = MISSING,
@ -215,15 +216,12 @@ def handle_message_parameters(
payload['allowed_mentions'] = previous_allowed_mentions.to_dict()
if mention_author is not None:
try:
payload['allowed_mentions']['replied_user'] = mention_author
except KeyError:
payload['allowed_mentions'] = {
'replied_user': mention_author,
}
if 'allowed_mentions' not in payload:
payload['allowed_mentions'] = AllowedMentions().to_dict()
payload['allowed_mentions']['replied_user'] = mention_author
if attachments is MISSING:
attachments = files # type: ignore
attachments = files
else:
files = [a for a in attachments if isinstance(a, File)]
@ -322,23 +320,24 @@ class HTTPClient:
def __init__(
self,
loop: asyncio.AbstractEventLoop,
connector: Optional[aiohttp.BaseConnector] = None,
*,
proxy: Optional[str] = None,
proxy_auth: Optional[aiohttp.BasicAuth] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
unsync_clock: bool = True,
http_trace: Optional[aiohttp.TraceConfig] = None,
) -> None:
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self.connector: aiohttp.BaseConnector = connector or aiohttp.TCPConnector(limit=0)
self.loop: asyncio.AbstractEventLoop = loop
self.connector: aiohttp.BaseConnector = connector or MISSING
self.__session: aiohttp.ClientSession = MISSING
self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self._global_over: asyncio.Event = asyncio.Event()
self._global_over.set()
self._global_over: asyncio.Event = MISSING
self.token: Optional[str] = None
self.ack_token: Optional[str] = None
self.proxy: Optional[str] = proxy
self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth
self.http_trace: Optional[aiohttp.TraceConfig] = http_trace
self.use_clock: bool = not unsync_clock
self.user_agent: str = MISSING
@ -357,7 +356,12 @@ class HTTPClient:
async def startup(self) -> None:
if self._started:
return
self.__session = session = aiohttp.ClientSession(connector=self.connector)
self.__session = session = aiohttp.ClientSession(
connector=self.connector,
loop=self.loop,
trace_configs=None if self.http_trace is None else [self.http_trace],
)
self.user_agent, self.browser_version, self.client_build_number = ua, bv, bn = await utils._get_info(session)
_log.info('Found user agent %s (%s), build number %s.', ua, bv, bn)
self.super_properties = sp = {
@ -587,7 +591,11 @@ class HTTPClient:
def recreate(self) -> None:
if self.__session and self.__session.closed:
self.__session = aiohttp.ClientSession(connector=self.connector)
self.__session = aiohttp.ClientSession(
connector=self.connector,
loop=self.loop,
trace_configs=None if self.http_trace is None else [self.http_trace],
)
async def close(self) -> None:
if self.__session:
@ -599,13 +607,19 @@ class HTTPClient:
self.token = token
self.ack_token = None
def get_me(self, with_analytics_token=True) -> Response[user.User]:
def get_me(self, with_analytics_token: bool = True) -> Response[user.User]:
params = {'with_analytics_token': str(with_analytics_token).lower()}
return self.request(Route('GET', '/users/@me'), params=params)
async def static_login(self, token: str) -> user.User:
old_token, self.token = self.token, token
if self.connector is MISSING:
self.connector = aiohttp.TCPConnector(loop=self.loop, limit=0)
self._global_over = asyncio.Event()
self._global_over.set()
await self.startup()
try:
@ -1254,15 +1268,14 @@ class HTTPClient:
def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]:
return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]:
r = Route('PATCH', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)
def edit_template(self, guild_id: Snowflake, code: str, payload: Dict[str, Any]) -> Response[template.Template]:
valid_keys = (
'name',
'description',
)
payload = {k: v for k, v in payload.items() if k in valid_keys}
return self.request(r, json=payload)
return self.request(Route('PATCH', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code), json=payload)
def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]:
return self.request(Route('DELETE', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code))
@ -1300,7 +1313,7 @@ class HTTPClient:
guild_id: Snowflake,
days: int,
compute_prune_count: bool,
roles: List[str],
roles: Iterable[str],
*,
reason: Optional[str] = None,
) -> Response[guild.GuildPrune]:
@ -1317,7 +1330,7 @@ class HTTPClient:
self,
guild_id: Snowflake,
days: int,
roles: List[str],
roles: Iterable[str],
) -> Response[guild.GuildPrune]:
params: Dict[str, Any] = {
'days': days,
@ -1330,6 +1343,9 @@ class HTTPClient:
def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]:
return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id))
def get_sticker_guild(self, sticker_id: Snowflake) -> Response[guild.Guild]:
return self.request(Route('GET', '/stickers/{sticker_id}/guild', sticker_id=sticker_id))
def list_premium_sticker_packs(
self, country: str = 'US', locale: str = 'en-US', payment_source_id: Snowflake = MISSING
) -> Response[sticker.ListPremiumStickerPacks]:
@ -1413,6 +1429,9 @@ class HTTPClient:
def get_custom_emoji(self, guild_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]:
return self.request(Route('GET', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id))
def get_emoji_guild(self, emoji_id: Snowflake) -> Response[guild.Guild]:
return self.request(Route('GET', '/emojis/{emoji_id}', emoji_id=emoji_id))
def create_custom_emoji(
self,
guild_id: Snowflake,
@ -1533,7 +1552,9 @@ class HTTPClient:
def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]:
return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id))
def edit_widget(self, guild_id: Snowflake, payload, reason: Optional[str] = None) -> Response[widget.WidgetSettings]:
def edit_widget(
self, guild_id: Snowflake, payload: widget.EditWidgetSettings, reason: Optional[str] = None
) -> Response[widget.WidgetSettings]:
return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload, reason=reason)
def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]:

21
discord/integrations.py

@ -39,6 +39,9 @@ __all__ = (
)
if TYPE_CHECKING:
from .guild import Guild
from .role import Role
from .state import ConnectionState
from .types.integration import (
IntegrationAccount as IntegrationAccountPayload,
Integration as IntegrationPayload,
@ -47,8 +50,6 @@ if TYPE_CHECKING:
IntegrationType,
IntegrationApplication as IntegrationApplicationPayload,
)
from .guild import Guild
from .role import Role
class IntegrationAccount:
@ -109,11 +110,11 @@ class Integration:
)
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
self.guild = guild
self._state = guild._state
self.guild: Guild = guild
self._state: ConnectionState = guild._state
self._from_data(data)
def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None:
@ -123,7 +124,7 @@ class Integration:
self.account: IntegrationAccount = IntegrationAccount(data['account'])
user = data.get('user')
self.user = User(state=self._state, data=user) if user else None
self.user: Optional[User] = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled']
async def delete(self, *, reason: Optional[str] = None) -> None:
@ -232,8 +233,8 @@ class StreamIntegration(Integration):
do this.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`TypeError`.
This function will now raise :exc:`TypeError` instead of
``InvalidArgument``.
Parameters
-----------
@ -319,7 +320,7 @@ class IntegrationApplication:
'user',
)
def __init__(self, *, data: IntegrationApplicationPayload, state):
def __init__(self, *, data: IntegrationApplicationPayload, state: ConnectionState) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
self.icon: Optional[str] = data['icon']
@ -358,7 +359,7 @@ class BotIntegration(Integration):
def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state)
self.application: IntegrationApplication = IntegrationApplication(data=data['application'], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:

59
discord/invite.py

@ -30,7 +30,7 @@ from .utils import parse_time, snowflake_time, _get_as_snowflake, MISSING
from .object import Object
from .mixins import Hashable
from .scheduled_event import ScheduledEvent
from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, try_enum
from .enums import ChannelType, VerificationLevel, InviteTarget, InviteType, NSFWLevel, try_enum
from .welcome_screen import WelcomeScreen
__all__ = (
@ -165,9 +165,34 @@ class PartialInviteGuild:
A list of features the guild has. See :attr:`Guild.features` for more information.
description: Optional[:class:`str`]
The partial guild's description.
nsfw_level: :class:`NSFWLevel`
The partial guild's NSFW level.
.. versionadded:: 2.0
vanity_url_code: Optional[:class:`str`]
The partial guild's vanity URL code, if available.
.. versionadded:: 2.0
premium_subscription_count: :class:`int`
The number of "boosts" the partial guild currently has.
.. versionadded:: 2.0
"""
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description')
__slots__ = (
'_state',
'_icon',
'_banner',
'_splash',
'features',
'id',
'name',
'verification_level',
'description',
'vanity_url_code',
'nsfw_level',
'premium_subscription_count',
)
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
self._state: ConnectionState = state
@ -179,6 +204,9 @@ class PartialInviteGuild:
self._splash: Optional[str] = data.get('splash')
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level'))
self.description: Optional[str] = data.get('description')
self.vanity_url_code: Optional[str] = data.get('vanity_url_code')
self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, data.get('nsfw_level', 0))
self.premium_subscription_count: int = data.get('premium_subscription_count') or 0
def __str__(self) -> str:
return self.name
@ -194,6 +222,16 @@ class PartialInviteGuild:
""":class:`datetime.datetime`: Returns the guild's creation time in UTC."""
return snowflake_time(self.id)
@property
def vanity_url(self) -> Optional[str]:
"""Optional[:class:`str`]: The Discord vanity invite URL for this partial guild, if available.
.. versionadded:: 2.0
"""
if self.vanity_url_code is None:
return None
return f'{Invite.BASE}/{self.vanity_url_code}'
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's icon asset, if available."""
@ -446,12 +484,13 @@ class Invite(Hashable):
@classmethod
def from_gateway(cls, *, state: ConnectionState, data: GatewayInvitePayload) -> Self:
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = _get_as_snowflake(data, 'channel_id')
if guild_id is not None:
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) or Object(id=guild_id)
if channel_id is not None:
channel: Optional[InviteChannelType] = state.get_channel(channel_id) or Object(id=channel_id) # type: ignore
if guild is not None:
channel = (guild.get_channel(channel_id) or Object(id=channel_id)) if channel_id is not None else None
else:
guild = Object(id=guild_id) if guild_id is not None else None
channel = Object(id=channel_id) if channel_id is not None else None
return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore
@ -543,7 +582,7 @@ class Invite(Hashable):
Raises
------
:exc:`.HTTPException`
HTTPException
Using the invite failed.
Returns
@ -587,7 +626,7 @@ class Invite(Hashable):
Raises
------
:exc:`.HTTPException`
HTTPException
Using the invite failed.
Returns
@ -597,7 +636,7 @@ class Invite(Hashable):
"""
return await self.use()
async def delete(self, *, reason: Optional[str] = None):
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Revokes the instant invite.

58
discord/member.py

@ -28,7 +28,7 @@ import datetime
import inspect
import itertools
from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union
from typing import Any, Callable, Collection, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type
import discord.abc
@ -214,7 +214,7 @@ class _ClientStatus:
return self
def flatten_user(cls):
def flatten_user(cls: Any) -> Type[Member]:
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# Ignore private/special methods (or not)
# if attr.startswith('_'):
@ -331,7 +331,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
default_avatar: Asset
avatar: Optional[Asset]
dm_channel: Optional[DMChannel]
create_dm = User.create_dm
create_dm: Callable[[], Coroutine[Any, Any, DMChannel]]
mutual_guilds: List[Guild]
public_flags: PublicUserFlags
banner: Optional[Asset]
@ -361,10 +361,10 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -445,7 +445,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
if self._self:
return
self._activities = tuple(map(create_activity, data['activities']))
self._activities = tuple(create_activity(d, self._state) for d in data['activities'])
self._client_status._update(data['status'], data['client_status'])
if len(user) > 1:
@ -696,7 +696,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
async def ban(
self,
*,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1,
delete_message_days: int = 1,
reason: Optional[str] = None,
) -> None:
"""|coro|
@ -726,7 +726,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
mute: bool = MISSING,
deafen: bool = MISSING,
suppress: bool = MISSING,
roles: List[discord.abc.Snowflake] = MISSING,
roles: Collection[discord.abc.Snowflake] = MISSING,
voice_channel: Optional[VocalGuildChannel] = MISSING,
timed_out_until: Optional[datetime.datetime] = MISSING,
avatar: Optional[bytes] = MISSING,
@ -783,7 +783,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
.. versionadded:: 1.7
roles: List[:class:`Role`]
The member's new list of roles. This *replaces* the roles.
voice_channel: Optional[:class:`VoiceChannel`]
voice_channel: Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]
The voice channel to move the member to.
Pass ``None`` to kick them from voice.
timed_out_until: Optional[:class:`datetime.datetime`]
@ -913,7 +913,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
else:
await self._state.http.edit_my_voice_state(self.guild.id, payload)
async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None:
async def move_to(self, channel: Optional[VocalGuildChannel], *, reason: Optional[str] = None) -> None:
"""|coro|
Moves a member to a new voice channel (they must be connected first).
@ -928,7 +928,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
Parameters
-----------
channel: Optional[:class:`VoiceChannel`]
channel: Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]
The new voice channel to move the member to.
Pass ``None`` to kick them from voice.
reason: Optional[:class:`str`]
@ -936,6 +936,42 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag):
"""
await self.edit(voice_channel=channel, reason=reason)
async def timeout(self, when: Union[datetime.timedelta, datetime.datetime], /, *, reason: Optional[str] = None) -> None:
"""|coro|
Applies a time out to a member until the specified date time or for the
given :class:`datetime.timedelta`.
You must have the :attr:`~Permissions.moderate_members` permission to
use this.
This raises the same exceptions as :meth:`edit`.
Parameters
-----------
when: Union[:class:`datetime.timedelta`, :class:`datetime.datetime`]
If this is a :class:`datetime.timedelta` then it represents the amount of
time the member should be timed out for. If this is a :class:`datetime.datetime`
then it's when the member's timeout should expire. Note that the API only allows
for timeouts up to 28 days.
reason: Optional[:class:`str`]
The reason for doing this action. Shows up on the audit log.
Raises
-------
TypeError
The ``when`` parameter was the wrong type of the datetime was not timezone-aware.
"""
if isinstance(when, datetime.timedelta):
timed_out_until = utils.utcnow() + when
elif isinstance(when, datetime.datetime):
timed_out_until = when
else:
raise TypeError(f'expected datetime.datetime or datetime.timedelta not {when.__class__!r}')
await self.edit(timed_out_until=timed_out_until, reason=reason)
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
r"""|coro|

10
discord/mentions.py

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Union, List, TYPE_CHECKING, Any, Union
from typing import Union, List, TYPE_CHECKING, Any
# fmt: off
__all__ = (
@ -92,10 +92,10 @@ class AllowedMentions:
roles: Union[bool, List[Snowflake]] = default,
replied_user: bool = default,
):
self.everyone = everyone
self.users = users
self.roles = roles
self.replied_user = replied_user
self.everyone: bool = everyone
self.users: Union[bool, List[Snowflake]] = users
self.roles: Union[bool, List[Snowflake]] = roles
self.replied_user: bool = replied_user
@classmethod
def all(cls) -> Self:

2279
discord/message.py

File diff suppressed because it is too large

5
discord/opus.py

@ -363,7 +363,7 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm)
@ -373,8 +373,7 @@ class Encoder(_OpusStruct):
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore
return array.array('b', data[:ret]).tobytes()
class Decoder(_OpusStruct):

15
discord/partial_emoji.py

@ -42,6 +42,7 @@ if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
from .types.activity import ActivityEmoji
class _EmojiTag:
@ -99,13 +100,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
id: Optional[int]
def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None):
self.animated = animated
self.name = name
self.id = id
self.animated: bool = animated
self.name: str = name
self.id: Optional[int] = id
self._state: Optional[ConnectionState] = None
@classmethod
def from_dict(cls, data: Union[PartialEmojiPayload, Dict[str, Any]]) -> Self:
def from_dict(cls, data: Union[PartialEmojiPayload, ActivityEmoji, Dict[str, Any]]) -> Self:
return cls(
animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'),
@ -178,10 +179,10 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
def __repr__(self):
def __repr__(self) -> str:
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if self.is_unicode_emoji():
return isinstance(other, PartialEmoji) and self.name == other.name
@ -189,7 +190,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return self.id == other.id
return False
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:

19
discord/permissions.py

@ -234,12 +234,19 @@ class Permissions(BaseFlags):
@classmethod
def stage_moderator(cls) -> Self:
"""A factory method that creates a :class:`Permissions` with all
"Stage Moderator" permissions from the official Discord UI set to ``True``.
"""A factory method that creates a :class:`Permissions` with all permissions
for stage moderators set to ``True``. These permissions are currently:
- :attr:`manage_channels`
- :attr:`mute_members`
- :attr:`move_members`
.. versionadded:: 1.7
.. versionchanged:: 2.0
Added :attr:`manage_channels` permission and removed :attr:`request_to_speak` permission.
"""
return cls(0b100000001010000000000000000000000)
return cls(0b1010000000000000000010000)
@classmethod
def advanced(cls) -> Self:
@ -279,7 +286,7 @@ class Permissions(BaseFlags):
# So 0000 OP2 0101 -> 0101
# The OP is base & ~denied.
# The OP2 is base | allowed.
self.value = (self.value & ~deny) | allow
self.value: int = (self.value & ~deny) | allow
@flag_value
def create_instant_invite(self) -> int:
@ -697,7 +704,7 @@ class PermissionOverwrite:
setattr(self, key, value)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, PermissionOverwrite) and self._values == other._values
def _set(self, key: str, value: Optional[bool]) -> None:
@ -750,7 +757,7 @@ class PermissionOverwrite:
"""
return len(self._values) == 0
def update(self, **kwargs: bool) -> None:
def update(self, **kwargs: Optional[bool]) -> None:
r"""Bulk updates this permission overwrite object.
Allows you to set multiple attributes by using keyword

31
discord/player.py

@ -365,12 +365,11 @@ class FFmpegOpusAudio(FFmpegAudio):
bitrate: Optional[int] = None,
codec: Optional[str] = None,
executable: str = 'ffmpeg',
pipe=False,
stderr=None,
before_options=None,
options=None,
pipe: bool = False,
stderr: Optional[IO[bytes]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None,
) -> None:
args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr}
@ -521,9 +520,9 @@ class FFmpegOpusAudio(FFmpegAudio):
raise TypeError(f"Expected str or callable for parameter 'probe', not '{method.__class__.__name__}'")
codec = bitrate = None
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
try:
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable))
except Exception:
if not fallback:
_log.exception("Probe '%s' using '%s' failed", method, executable)
@ -531,7 +530,7 @@ class FFmpegOpusAudio(FFmpegAudio):
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable))
except Exception:
_log.exception("Fallback probe using '%s' failed", executable)
else:
@ -635,7 +634,13 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
def __init__(self, source: AudioSource, client: Player, *, after=None):
def __init__(
self,
source: AudioSource,
client: Player,
*,
after: Optional[Callable[[Optional[Exception]], Any]] = None,
) -> None:
threading.Thread.__init__(self)
self.daemon: bool = True
self.source: AudioSource = source
@ -644,7 +649,7 @@ class AudioPlayer(threading.Thread):
self._end: threading.Event = threading.Event()
self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused
self._resumed.set() # We are not paused
self._current_error: Optional[Exception] = None
self._connected: threading.Event = client.client._connected
self._lock: threading.Lock = threading.Lock()
@ -724,8 +729,8 @@ class AudioPlayer(threading.Thread):
self._speak(SpeakingState.none)
def resume(self, *, update_speaking: bool = True) -> None:
self.loops = 0
self._start = time.perf_counter()
self.loops: int = 0
self._start: float = time.perf_counter()
self._resumed.set()
if update_speaking:
self._speak(SpeakingState.voice)
@ -744,6 +749,6 @@ class AudioPlayer(threading.Thread):
def _speak(self, speaking: SpeakingState) -> None:
try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.client.loop)
except Exception as e:
_log.info("Speaking call in player failed: %s", e)

10
discord/reaction.py

@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, TYPE_CHECKING, AsyncIterator, Union, Optional
from .user import User
from .object import Object
# fmt: off
@ -34,7 +35,6 @@ __all__ = (
# fmt: on
if TYPE_CHECKING:
from .user import User
from .member import Member
from .types.message import Reaction as ReactionPayload
from .message import Message
@ -94,10 +94,10 @@ class Reaction:
""":class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return other.emoji != self.emoji
return True
@ -149,8 +149,8 @@ class Reaction:
.. versionadded:: 1.3
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Raises
--------

8
discord/role.py

@ -212,7 +212,7 @@ class Role(Hashable):
def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>'
def __lt__(self, other: Any) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented
@ -242,7 +242,7 @@ class Role(Hashable):
def __gt__(self, other: Any) -> bool:
return Role.__lt__(other, self)
def __ge__(self, other: Any) -> bool:
def __ge__(self, other: object) -> bool:
r = Role.__lt__(self, other)
if r is NotImplemented:
return NotImplemented
@ -416,8 +416,8 @@ class Role(Hashable):
The ``display_icon``, ``icon``, and ``unicode_emoji`` keyword-only parameters were added.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
-----------

78
discord/scheduled_event.py

@ -41,6 +41,7 @@ if TYPE_CHECKING:
)
from .abc import Snowflake
from .guild import Guild
from .channel import VoiceChannel, StageChannel
from .state import ConnectionState
from .user import User
@ -79,15 +80,15 @@ class ScheduledEvent(Hashable):
The scheduled event's ID.
name: :class:`str`
The name of the scheduled event.
description: :class:`str`
description: Optional[:class:`str`]
The description of the scheduled event.
entity_type: :class:`EntityType`
The type of entity this event is for.
entity_id: :class:`int`
The ID of the entity this event is for.
entity_id: Optional[:class:`int`]
The ID of the entity this event is for if available.
start_time: :class:`datetime.datetime`
The time that the scheduled event will start in UTC.
end_time: :class:`datetime.datetime`
end_time: Optional[:class:`datetime.datetime`]
The time that the scheduled event will end in UTC.
privacy_level: :class:`PrivacyLevel`
The privacy level of the scheduled event.
@ -130,9 +131,9 @@ class ScheduledEvent(Hashable):
self.id: int = int(data['id'])
self.guild_id: int = int(data['guild_id'])
self.name: str = data['name']
self.description: str = data.get('description', '')
self.entity_type = try_enum(EntityType, data['entity_type'])
self.entity_id: int = int(data['id'])
self.description: Optional[str] = data.get('description')
self.entity_type: EntityType = try_enum(EntityType, data['entity_type'])
self.entity_id: Optional[int] = _get_as_snowflake(data, 'entity_id')
self.start_time: datetime = parse_time(data['scheduled_start_time'])
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['status'])
self.status: EventStatus = try_enum(EventStatus, data['status'])
@ -145,15 +146,14 @@ class ScheduledEvent(Hashable):
self.end_time: Optional[datetime] = parse_time(data.get('scheduled_end_time'))
self.channel_id: Optional[int] = _get_as_snowflake(data, 'channel_id')
metadata = data.get('metadata')
if metadata:
self._unroll_metadata(metadata)
metadata = data.get('entity_metadata')
self._unroll_metadata(metadata)
def _unroll_metadata(self, data: EntityMetadata):
self.location: Optional[str] = data.get('location')
def _unroll_metadata(self, data: Optional[EntityMetadata]):
self.location: Optional[str] = data.get('location') if data else None
@classmethod
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload):
def from_creation(cls, *, state: ConnectionState, data: GuildScheduledEventPayload) -> None:
creator_id = data.get('creator_id')
self = cls(state=state, data=data)
if creator_id:
@ -169,11 +169,21 @@ class ScheduledEvent(Hashable):
return None
return Asset._from_scheduled_event_cover_image(self._state, self.id, self._cover_image)
@property
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild this scheduled event is in."""
return self._state._get_guild(self.guild_id)
@property
def channel(self) -> Optional[Union[VoiceChannel, StageChannel]]:
"""Optional[Union[:class:`VoiceChannel`, :class:`StageChannel`]]: The channel this scheduled event is in."""
return self.guild.get_channel(self.channel_id) # type: ignore
@property
def url(self) -> str:
""":class:`str`: The url for the scheduled event."""
return f'https://discord.com/events/{self.guild_id}/{self.id}'
async def start(self, *, reason: Optional[str] = None) -> ScheduledEvent:
"""|coro|
@ -286,7 +296,7 @@ class ScheduledEvent(Hashable):
description: str = MISSING,
channel: Optional[Snowflake] = MISSING,
start_time: datetime = MISSING,
end_time: datetime = MISSING,
end_time: Optional[datetime] = MISSING,
privacy_level: PrivacyLevel = MISSING,
entity_type: EntityType = MISSING,
status: EventStatus = MISSING,
@ -314,10 +324,14 @@ class ScheduledEvent(Hashable):
start_time: :class:`datetime.datetime`
The time that the scheduled event will start. This must be a timezone-aware
datetime object. Consider using :func:`utils.utcnow`.
end_time: :class:`datetime.datetime`
end_time: Optional[:class:`datetime.datetime`]
The time that the scheduled event will end. This must be a timezone-aware
datetime object. Consider using :func:`utils.utcnow`.
If the entity type is either :attr:`EntityType.voice` or
:attr:`EntityType.stage_instance`, the end_time can be cleared by
passing ``None``.
Required if the entity type is :attr:`EntityType.external`.
privacy_level: :class:`PrivacyLevel`
The privacy level of the scheduled event.
@ -325,8 +339,8 @@ class ScheduledEvent(Hashable):
The new entity type.
status: :class:`EventStatus`
The new status of the scheduled event.
image: :class:`bytes`
The new image of the scheduled event.
image: Optional[:class:`bytes`]
The new image of the scheduled event or ``None`` to remove the image.
location: :class:`str`
The new location of the scheduled event.
@ -383,7 +397,7 @@ class ScheduledEvent(Hashable):
payload['status'] = status.value
if image is not MISSING:
image_as_str: str = _bytes_to_base64_data(image)
image_as_str: Optional[str] = _bytes_to_base64_data(image) if image is not None else image
payload['image'] = image_as_str
if entity_type is not MISSING:
@ -400,25 +414,31 @@ class ScheduledEvent(Hashable):
payload['channel_id'] = channel.id
if location is not MISSING:
if location not in (MISSING, None):
raise TypeError('location cannot be set when entity_type is voice or stage_instance')
payload['entity_metadata'] = None
else:
if channel is not MISSING:
if channel not in (MISSING, None):
raise TypeError('channel cannot be set when entity_type is external')
payload['channel_id'] = None
if location is MISSING or location is None:
raise TypeError('location must be set when entity_type is external')
metadata['location'] = location
if end_time is MISSING:
if end_time is MISSING or end_time is None:
raise TypeError('end_time must be set when entity_type is external')
if end_time.tzinfo is None:
raise ValueError(
'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.'
)
payload['scheduled_end_time'] = end_time.isoformat()
if end_time is not MISSING:
if end_time is not None:
if end_time.tzinfo is None:
raise ValueError(
'end_time must be an aware datetime. Consider using discord.utils.utcnow() or datetime.datetime.now().astimezone() for local time.'
)
payload['scheduled_end_time'] = end_time.isoformat()
else:
payload['scheduled_end_time'] = end_time
if metadata:
payload['entity_metadata'] = metadata
@ -459,7 +479,7 @@ class ScheduledEvent(Hashable):
) -> AsyncIterator[User]:
"""|coro|
Retrieves all :class:`User` that are in this thread.
Retrieves all :class:`User` that are subscribed to this event.
This requires :attr:`Intents.members` to get information about members
other than yourself.
@ -472,7 +492,7 @@ class ScheduledEvent(Hashable):
Returns
--------
List[:class:`User`]
All thread members in the thread.
All subscribed users of this event.
"""
async def _before_strategy(retrieve, before, limit):
@ -548,4 +568,4 @@ class ScheduledEvent(Hashable):
self._users[user.id] = user
def _pop_user(self, user_id: int) -> None:
self._users.pop(user_id)
self._users.pop(user_id, None)

22
discord/stage_instance.py

@ -26,7 +26,7 @@ from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from .utils import MISSING, cached_slot_property
from .utils import MISSING, cached_slot_property, _get_as_snowflake
from .mixins import Hashable
from .enums import PrivacyLevel, try_enum
@ -41,6 +41,7 @@ if TYPE_CHECKING:
from .state import ConnectionState
from .channel import StageChannel
from .guild import Guild
from .scheduled_event import ScheduledEvent
class StageInstance(Hashable):
@ -76,6 +77,10 @@ class StageInstance(Hashable):
The privacy level of the stage instance.
discoverable_disabled: :class:`bool`
Whether discoverability for the stage instance is disabled.
scheduled_event_id: Optional[:class:`int`]
The ID of scheduled event that belongs to the stage instance if any.
.. versionadded:: 2.0
"""
__slots__ = (
@ -86,20 +91,23 @@ class StageInstance(Hashable):
'topic',
'privacy_level',
'discoverable_disabled',
'scheduled_event_id',
'_cs_channel',
'_cs_scheduled_event',
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
self._state = state
self.guild = guild
self._state: ConnectionState = state
self.guild: Guild = guild
self._update(data)
def _update(self, data: StageInstancePayload):
def _update(self, data: StageInstancePayload) -> None:
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']
self.privacy_level: PrivacyLevel = try_enum(PrivacyLevel, data['privacy_level'])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
self.scheduled_event_id: Optional[int] = _get_as_snowflake(data, 'guild_scheduled_event_id')
def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
@ -115,6 +123,12 @@ class StageInstance(Hashable):
# The returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore
@cached_slot_property('_cs_scheduled_event')
def scheduled_event(self) -> Optional[ScheduledEvent]:
"""Optional[:class:`ScheduledEvent`]: The scheduled event that belongs to the stage instance."""
# Guild.get_scheduled_event() expects an int, we are passing Optional[int]
return self.guild.get_scheduled_event(self.scheduled_event_id) # type: ignore
async def edit(
self,
*,

74
discord/state.py

@ -41,6 +41,8 @@ from typing import (
Coroutine,
Tuple,
Deque,
Literal,
overload,
)
import weakref
import inspect
@ -93,7 +95,7 @@ if TYPE_CHECKING:
from .types.activity import Activity as ActivityPayload
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload, PartialUser as PartialUserPayload
from .types.emoji import Emoji as EmojiPayload
from .types.emoji import Emoji as EmojiPayload, PartialEmoji as PartialEmojiPayload
from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload
@ -376,24 +378,24 @@ class ConnectionState:
def __init__(
self,
*,
dispatch: Callable,
handlers: Dict[str, Callable],
hooks: Dict[str, Callable],
dispatch: Callable[..., Any],
handlers: Dict[str, Callable[..., Any]],
hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]],
http: HTTPClient,
loop: asyncio.AbstractEventLoop,
client: Client,
**options: Any,
) -> None:
self.loop: asyncio.AbstractEventLoop = loop
# Set later, after Client.login
self.loop: asyncio.AbstractEventLoop = utils.MISSING
self.http: HTTPClient = http
self.client = client
self.max_messages: Optional[int] = options.get('max_messages', 1000)
if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000
self.dispatch: Callable = dispatch
self.handlers: Dict[str, Callable] = handlers
self.hooks: Dict[str, Callable] = hooks
self.dispatch: Callable[..., Any] = dispatch
self.handlers: Dict[str, Callable[..., Any]] = handlers
self.hooks: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = hooks
self._ready_task: Optional[asyncio.Task] = None
self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0)
@ -439,11 +441,11 @@ class ConnectionState:
if cache_flags._empty:
self.store_user = self.create_user # type: ignore
parsers = {}
self.parsers: Dict[str, Callable[[Any], None]]
self.parsers = parsers = {}
for attr, func in inspect.getmembers(self):
if attr.startswith('parse_'):
parsers[attr[6:].upper()] = func
self.parsers: Dict[str, Callable[[Dict[str, Any]], None]] = parsers
self.clear()
@ -505,6 +507,9 @@ class ConnectionState:
else:
await coro(*args, **kwargs)
async def async_setup(self) -> None:
pass
@property
def session_id(self) -> Optional[str]:
return self.ws.session_id
@ -588,7 +593,7 @@ class ConnectionState:
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> User:
return User(state=self, data=data)
def get_user(self, id):
def get_user(self, id: int) -> Optional[User]:
return self._users.get(id)
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
@ -1049,7 +1054,7 @@ class ConnectionState:
if old_member._client_status != member._client_status or old_member._activities != member._activities:
self.dispatch('presence_update', old_member, member)
def parse_user_update(self, data: gw.UserUpdateEvent):
def parse_user_update(self, data: gw.UserUpdateEvent) -> None:
if self.user:
self.user._update(data)
@ -1260,6 +1265,8 @@ class ConnectionState:
existing = guild.get_thread(int(data['id']))
if existing is not None:
old = existing._update(data)
if existing.archived:
guild._remove_thread(existing)
if old is not None:
self.dispatch('thread_update', old, existing)
else: # Shouldn't happen
@ -1397,10 +1404,8 @@ class ConnectionState:
def parse_guild_member_remove(self, data: gw.GuildMemberRemoveEvent) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
try:
if guild._member_count is not None:
guild._member_count -= 1
except AttributeError:
pass
user_id = int(data['user']['id'])
member = guild.get_member(user_id)
@ -1630,7 +1635,7 @@ class ConnectionState:
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers']))
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data):
def _get_create_guild(self, data: gw.GuildCreateEvent):
guild = self._get_guild(int(data['id']))
# Discord being Discord sends a GUILD_CREATE after an OPCode 14 is sent (a la bots)
# However, we want that if we forced a GUILD_CREATE for an unavailable guild
@ -1640,7 +1645,7 @@ class ConnectionState:
return self._add_guild_from_data(data)
def is_guild_evicted(self, guild) -> bool:
def is_guild_evicted(self, guild: Guild) -> bool:
return guild.id not in self._guilds
async def assert_guild_presence_count(self, guild: Guild):
@ -1706,9 +1711,15 @@ class ConnectionState:
)
request.start()
if wait:
return await request.wait()
return request.get_future()
@overload
async def chunk_guild(self, guild: Guild, *, wait: Literal[True] = ..., channels: List[abcSnowflake] = ...) -> Optional[List[Member]]:
...
@overload
async def chunk_guild(
self, guild: Guild, *, wait: Literal[False] = ..., channels: List[abcSnowflake] = ...
) -> asyncio.Future[Optional[List[Member]]]:
...
async def chunk_guild(
self,
@ -1716,7 +1727,7 @@ class ConnectionState:
*,
wait: bool = True,
channels: List[abcSnowflake] = MISSING,
):
) -> Union[asyncio.Future[Optional[List[Member]]], Optional[List[Member]]]:
if not guild.me:
await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore - self_id is always present here
@ -1960,7 +1971,7 @@ class ConnectionState:
if guild is not None:
scheduled_event = ScheduledEvent(state=self, data=data)
guild._scheduled_events[scheduled_event.id] = scheduled_event
self.dispatch('scheduled_event_create', guild, scheduled_event)
self.dispatch('scheduled_event_create', scheduled_event)
else:
_log.debug('SCHEDULED_EVENT_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
@ -1971,7 +1982,7 @@ class ConnectionState:
if scheduled_event is not None:
old_scheduled_event = copy.copy(scheduled_event)
scheduled_event._update(data)
self.dispatch('scheduled_event_update', guild, old_scheduled_event, scheduled_event)
self.dispatch('scheduled_event_update', old_scheduled_event, scheduled_event)
else:
_log.debug('SCHEDULED_EVENT_UPDATE referencing unknown scheduled event ID: %s. Discarding.', data['id'])
else:
@ -1985,7 +1996,7 @@ class ConnectionState:
except KeyError:
pass
else:
self.dispatch('scheduled_event_delete', guild, scheduled_event)
self.dispatch('scheduled_event_delete', scheduled_event)
else:
_log.debug('SCHEDULED_EVENT_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
@ -1997,10 +2008,9 @@ class ConnectionState:
user = self.get_user(int(data['user_id']))
if user is not None:
scheduled_event._add_user(user)
self.dispatch('scheduled_event_user_add', guild, scheduled_event, user)
self.dispatch('scheduled_event_user_add', scheduled_event, user)
else:
_log.debug('SCHEDULED_EVENT_USER_ADD referencing unknown user ID: %s. Discarding.', data['user_id'])
self.dispatch('scheduled_event_user_add', guild, scheduled_event, user)
else:
_log.debug(
'SCHEDULED_EVENT_USER_ADD referencing unknown scheduled event ID: %s. Discarding.',
@ -2020,7 +2030,6 @@ class ConnectionState:
self.dispatch('scheduled_event_user_remove', scheduled_event, user)
else:
_log.debug('SCHEDULED_EVENT_USER_REMOVE referencing unknown user ID: %s. Discarding.', data['user_id'])
self.dispatch('scheduled_event_user_remove', scheduled_event, user)
else:
_log.debug(
'SCHEDULED_EVENT_USER_REMOVE referencing unknown scheduled event ID: %s. Discarding.',
@ -2173,16 +2182,19 @@ class ConnectionState:
return channel.guild.get_member(user_id)
return self.get_user(user_id)
def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]:
def get_reaction_emoji(self, data: PartialEmojiPayload) -> Union[Emoji, PartialEmoji, str]:
emoji_id = utils._get_as_snowflake(data, 'id')
if not emoji_id:
return data['name']
# the name key will be a str
return data['name'] # type: ignore
try:
return self._emojis[emoji_id]
except KeyError:
return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name'])
return PartialEmoji.with_state(
self, animated=data.get('animated', False), id=emoji_id, name=data['name'] # type: ignore
)
def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]:
emoji_id = emoji.id

4
discord/template.py

@ -174,8 +174,8 @@ class Template:
The ``region`` parameter has been removed.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
----------

70
discord/threads.py

@ -43,6 +43,7 @@ __all__ = (
if TYPE_CHECKING:
from datetime import datetime
from typing_extensions import Self
from .types.threads import (
Thread as ThreadPayload,
@ -143,13 +144,13 @@ class Thread(Messageable, Hashable):
'_created_at',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload) -> None:
self._state: ConnectionState = state
self.guild = guild
self.guild: Guild = guild
self._members: Dict[int, ThreadMember] = {}
self._from_data(data)
async def _get_channel(self):
async def _get_channel(self) -> Self:
return self
def __repr__(self) -> str:
@ -162,26 +163,25 @@ class Thread(Messageable, Hashable):
return self.name
def _from_data(self, data: ThreadPayload):
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.owner_id = int(data['owner_id'])
self.name = data['name']
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self._member_ids = data['member_ids_preview']
self.id: int = int(data['id'])
self.parent_id: int = int(data['parent_id'])
self.owner_id: int = int(data['owner_id'])
self.name: str = data['name']
self._type: ChannelType = try_enum(ChannelType, data['type'])
self.last_message_id: Optional[int] = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay: int = data.get('rate_limit_per_user', 0)
self.message_count: int = data['message_count']
self.member_count: int = data['member_count']
self._member_ids: List[Union[str, int]] = data['member_ids_preview']
self._unroll_metadata(data['thread_metadata'])
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived']
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self._created_at = parse_time(data.get('creation_timestamp'))
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
self._created_at = parse_time(data.get('create_timestamp'))
self.archived: bool = data['archived']
self.auto_archive_duration: int = data['auto_archive_duration']
self.archive_timestamp: datetime = parse_time(data['archive_timestamp'])
self.locked: bool = data.get('locked', False)
self.invitable: bool = data.get('invitable', True)
self._created_at: Optional[datetime] = parse_time(data.get('create_timestamp'))
def _update(self, data):
old = copy.copy(self)
@ -249,8 +249,7 @@ class Thread(Messageable, Hashable):
def members(self) -> List[ThreadMember]:
"""List[:class:`ThreadMember`]: A list of thread members in this thread.
Initial members are not provided by Discord. You must call :func:`fetch_members`
or have thread subscribing enabled.
Initial members are not provided by Discord. You must call :func:`fetch_members`.
"""
return list(self._members.values())
@ -577,7 +576,7 @@ class Thread(Messageable, Hashable):
# The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
async def join(self):
async def join(self) -> None:
"""|coro|
Joins this thread.
@ -594,7 +593,7 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.join_thread(self.id)
async def leave(self):
async def leave(self) -> None:
"""|coro|
Leaves this thread.
@ -606,7 +605,7 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.leave_thread(self.id)
async def add_user(self, user: Snowflake, /):
async def add_user(self, user: Snowflake, /) -> None:
"""|coro|
Adds a user to this thread.
@ -629,7 +628,7 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.add_user_to_thread(self.id, user.id)
async def remove_user(self, user: Snowflake, /):
async def remove_user(self, user: Snowflake, /) -> None:
"""|coro|
Removes a user from this thread.
@ -680,7 +679,7 @@ class Thread(Messageable, Hashable):
return self.members # Includes correct self.me
async def delete(self):
async def delete(self) -> None:
"""|coro|
Deletes this thread.
@ -772,23 +771,24 @@ class ThreadMember(Hashable):
'parent',
)
def __init__(self, parent: Thread, data: ThreadMemberPayload):
self.parent = parent
self._state = parent._state
def __init__(self, parent: Thread, data: ThreadMemberPayload) -> None:
self.parent: Thread = parent
self._state: ConnectionState = parent._state
self._from_data(data)
def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload):
def _from_data(self, data: ThreadMemberPayload) -> None:
state = self._state
self.id: int
try:
self.id = int(data['user_id'])
except KeyError:
assert state.self_id is not None
self.id = state.self_id
self.id = state.self_id # type: ignore
self.thread_id: int
try:
self.thread_id = int(data['id'])
except KeyError:
@ -798,7 +798,7 @@ class ThreadMember(Hashable):
self.flags = data.get('flags')
if (mdata := data.get('member')) is not None:
guild = self.parent.parent.guild # type: ignore
guild = self.parent.guild
mdata['guild_id'] = guild.id
self.id = user_id = int(data['user_id'])
mdata['presence'] = data.get('presence')
@ -817,4 +817,4 @@ class ThreadMember(Hashable):
"""Optional[:class:`Member`]: The member this :class:`ThreadMember` represents. If the member
is not cached then this will be ``None``.
"""
return self.parent.parent.guild.get_member(self.id) # type: ignore
return self.parent.guild.get_member(self.id)

1
discord/types/activity.py

@ -112,3 +112,4 @@ class Activity(_BaseActivity, total=False):
session_id: Optional[str]
instance: bool
buttons: List[ActivityButton]
sync_id: str

1
discord/types/channel.py

@ -156,3 +156,4 @@ class StageInstance(TypedDict):
topic: str
privacy_level: PrivacyLevel
discoverable_disabled: bool
guild_scheduled_event_id: Optional[int]

242
discord/types/interactions.py

@ -0,0 +1,242 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, Union
from .channel import ChannelTypeWithoutThread, ThreadMetadata
from .threads import ThreadType
from .member import Member
from .message import Attachment
from .role import Role
from .snowflake import Snowflake
from .user import User
if TYPE_CHECKING:
from .message import Message
InteractionType = Literal[1, 2, 3, 4, 5]
class _BasePartialChannel(TypedDict):
id: Snowflake
name: str
permissions: str
class PartialChannel(_BasePartialChannel):
type: ChannelTypeWithoutThread
class PartialThread(_BasePartialChannel):
type: ThreadType
thread_metadata: ThreadMetadata
parent_id: Snowflake
class ResolvedData(TypedDict, total=False):
users: Dict[str, User]
members: Dict[str, Member]
roles: Dict[str, Role]
channels: Dict[str, Union[PartialChannel, PartialThread]]
messages: Dict[str, Message]
attachments: Dict[str, Attachment]
class _BaseApplicationCommandInteractionDataOption(TypedDict):
name: str
class _CommandGroupApplicationCommandInteractionDataOption(_BaseApplicationCommandInteractionDataOption):
type: Literal[1, 2]
options: List[ApplicationCommandInteractionDataOption]
class _BaseValueApplicationCommandInteractionDataOption(_BaseApplicationCommandInteractionDataOption, total=False):
focused: bool
class _StringValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[3]
value: str
class _IntegerValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[4]
value: int
class _BooleanValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[5]
value: bool
class _SnowflakeValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[6, 7, 8, 9, 11]
value: Snowflake
class _NumberValueApplicationCommandInteractionDataOption(_BaseValueApplicationCommandInteractionDataOption):
type: Literal[10]
value: float
_ValueApplicationCommandInteractionDataOption = Union[
_StringValueApplicationCommandInteractionDataOption,
_IntegerValueApplicationCommandInteractionDataOption,
_BooleanValueApplicationCommandInteractionDataOption,
_SnowflakeValueApplicationCommandInteractionDataOption,
_NumberValueApplicationCommandInteractionDataOption,
]
ApplicationCommandInteractionDataOption = Union[
_CommandGroupApplicationCommandInteractionDataOption,
_ValueApplicationCommandInteractionDataOption,
]
class _BaseApplicationCommandInteractionDataOptional(TypedDict, total=False):
resolved: ResolvedData
guild_id: Snowflake
class _BaseApplicationCommandInteractionData(_BaseApplicationCommandInteractionDataOptional):
id: Snowflake
name: str
class ChatInputApplicationCommandInteractionData(_BaseApplicationCommandInteractionData, total=False):
type: Literal[1]
options: List[ApplicationCommandInteractionDataOption]
class _BaseNonChatInputApplicationCommandInteractionData(_BaseApplicationCommandInteractionData):
target_id: Snowflake
class UserApplicationCommandInteractionData(_BaseNonChatInputApplicationCommandInteractionData):
type: Literal[2]
class MessageApplicationCommandInteractionData(_BaseNonChatInputApplicationCommandInteractionData):
type: Literal[3]
ApplicationCommandInteractionData = Union[
ChatInputApplicationCommandInteractionData,
UserApplicationCommandInteractionData,
MessageApplicationCommandInteractionData,
]
class _BaseMessageComponentInteractionData(TypedDict):
custom_id: str
class ButtonMessageComponentInteractionData(_BaseMessageComponentInteractionData):
component_type: Literal[2]
class SelectMessageComponentInteractionData(_BaseMessageComponentInteractionData):
component_type: Literal[3]
values: List[str]
MessageComponentInteractionData = Union[ButtonMessageComponentInteractionData, SelectMessageComponentInteractionData]
class ModalSubmitTextInputInteractionData(TypedDict):
type: Literal[4]
custom_id: str
value: str
ModalSubmitComponentItemInteractionData = ModalSubmitTextInputInteractionData
class ModalSubmitActionRowInteractionData(TypedDict):
type: Literal[1]
components: List[ModalSubmitComponentItemInteractionData]
ModalSubmitComponentInteractionData = Union[ModalSubmitActionRowInteractionData, ModalSubmitComponentItemInteractionData]
class ModalSubmitInteractionData(TypedDict):
custom_id: str
components: List[ModalSubmitActionRowInteractionData]
InteractionData = Union[
ApplicationCommandInteractionData,
MessageComponentInteractionData,
ModalSubmitInteractionData,
]
class _BaseInteractionOptional(TypedDict, total=False):
guild_id: Snowflake
channel_id: Snowflake
locale: str
guild_locale: str
class _BaseInteraction(_BaseInteractionOptional):
id: Snowflake
application_id: Snowflake
token: str
version: Literal[1]
class PingInteraction(_BaseInteraction):
type: Literal[1]
class ApplicationCommandInteraction(_BaseInteraction):
type: Literal[2, 4]
data: ApplicationCommandInteractionData
class MessageComponentInteraction(_BaseInteraction):
type: Literal[3]
data: MessageComponentInteractionData
class ModalSubmitInteraction(_BaseInteraction):
type: Literal[5]
data: ModalSubmitInteractionData
Interaction = Union[PingInteraction, ApplicationCommandInteraction, MessageComponentInteraction, ModalSubmitInteraction]
class MessageInteraction(TypedDict):
id: Snowflake
type: InteractionType
name: str
user: User

4
discord/types/scheduled_event.py

@ -35,7 +35,7 @@ EntityType = Literal[1, 2, 3]
class _BaseGuildScheduledEventOptional(TypedDict, total=False):
creator_id: Optional[Snowflake]
description: str
description: Optional[str]
creator: User
user_count: int
image: Optional[str]
@ -75,7 +75,7 @@ class EntityMetadata(TypedDict):
class ExternalScheduledEvent(_BaseGuildScheduledEvent):
channel_id: Literal[None]
entity_metadata: EntityMetadata
scheduled_end_time: Optional[str]
scheduled_end_time: str
entity_type: Literal[3]

14
discord/types/widget.py

@ -46,18 +46,20 @@ class WidgetMember(User, total=False):
suppress: bool
class _WidgetOptional(TypedDict, total=False):
class Widget(TypedDict):
id: Snowflake
name: str
instant_invite: Optional[str]
channels: List[WidgetChannel]
members: List[WidgetMember]
presence_count: int
class Widget(_WidgetOptional):
id: Snowflake
name: str
instant_invite: str
class WidgetSettings(TypedDict):
enabled: bool
channel_id: Optional[Snowflake]
class WidgetSettings(TypedDict):
class EditWidgetSettings(TypedDict, total=False):
enabled: bool
channel_id: Optional[Snowflake]

10
discord/user.py

@ -245,10 +245,10 @@ class BaseUser(_UserTag):
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
@ -613,8 +613,8 @@ class ClientUser(BaseUser):
The edit is no longer in-place, instead the newly edited client user is returned.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
-----------
@ -845,7 +845,7 @@ class User(BaseUser, discord.abc.Connectable, discord.abc.Messageable):
def _get_voice_state_pair(self) -> Tuple[int, int]:
return self._state.self_id, self.dm_channel.id
async def _get_channel(self):
async def _get_channel(self) -> DMChannel:
ch = await self.create_dm()
return ch

53
discord/utils.py

@ -29,6 +29,7 @@ from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
@ -42,6 +43,7 @@ from typing import (
NamedTuple,
Optional,
Protocol,
Set,
Sequence,
Tuple,
Type,
@ -76,7 +78,7 @@ import yarl
from .enums import BrowserEnum
try:
import orjson
import orjson # type: ignore
except ModuleNotFoundError:
HAS_ORJSON = False
else:
@ -111,6 +113,9 @@ class _MissingSentinel:
def __bool__(self):
return False
def __hash__(self):
return 0
def __repr__(self):
return '...'
@ -137,7 +142,7 @@ if TYPE_CHECKING:
from aiohttp import ClientSession
from functools import cached_property as cached_property
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Self
from .permissions import Permissions
from .abc import Messageable, Snowflake
@ -151,8 +156,16 @@ if TYPE_CHECKING:
P = ParamSpec('P')
MaybeCoroFunc = Union[
Callable[P, Coroutine[Any, Any, 'T']],
Callable[P, 'T'],
]
_SnowflakeListBase = array.array[int]
else:
cached_property = _cached_property
_SnowflakeListBase = array.array
T = TypeVar('T')
@ -194,7 +207,7 @@ class classproperty(Generic[T_co]):
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance, value) -> None:
def __set__(self, instance: Optional[Any], value: Any) -> None:
raise AttributeError('cannot set attribute')
@ -226,7 +239,7 @@ class SequenceProxy(Sequence[T_co]):
def __reversed__(self) -> Iterator[T_co]:
return reversed(self.__proxied)
def index(self, value: Any, *args, **kwargs) -> int:
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
return self.__proxied.index(value, *args, **kwargs)
def count(self, value: Any) -> int:
@ -255,10 +268,10 @@ def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
def copy_doc(original: Callable) -> Callable[[T], T]:
def decorator(overriden: T) -> T:
overriden.__doc__ = original.__doc__
overriden.__signature__ = _signature(original) # type: ignore
return overriden
def decorator(overridden: T) -> T:
overridden.__doc__ = original.__doc__
overridden.__signature__ = _signature(original) # type: ignore
return overridden
return decorator
@ -588,9 +601,13 @@ def _bytes_to_base64_data(data: bytes) -> str:
return fmt.format(mime=mime, data=b64)
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + '.')
if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore
def _to_json(obj: Any) -> str:
return orjson.dumps(obj).decode('utf-8')
_from_json = orjson.loads # type: ignore
@ -614,15 +631,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
return float(reset_after)
async def maybe_coroutine(f, *args, **kwargs):
async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
value = f(*args, **kwargs)
if _isawaitable(value):
return await value
else:
return value
return value # type: ignore
async def async_all(gen, *, check=_isawaitable):
async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool:
for elem in gen:
if check(elem):
elem = await elem
@ -631,7 +648,7 @@ async def async_all(gen, *, check=_isawaitable):
return True
async def sane_wait_for(futures, *, timeout):
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]:
ensured = [asyncio.ensure_future(fut) for fut in futures]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
@ -649,7 +666,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]:
continue
def compute_timedelta(dt: datetime.datetime):
def compute_timedelta(dt: datetime.datetime) -> float:
if dt.tzinfo is None:
dt = dt.astimezone()
now = datetime.datetime.now(datetime.timezone.utc)
@ -698,7 +715,7 @@ def valid_icon_size(size: int) -> bool:
return not size & (size - 1) and 4096 >= size >= 16
class SnowflakeList(array.array):
class SnowflakeList(_SnowflakeListBase):
"""Internal data storage class to efficiently store a list of snowflakes.
This should have the following characteristics:
@ -717,7 +734,7 @@ class SnowflakeList(array.array):
def __init__(self, data: Iterable[int], *, is_sorted: bool = False):
...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self:
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None:
@ -1022,7 +1039,7 @@ def evaluate_annotation(
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
) -> Any:
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
@ -1092,7 +1109,7 @@ def is_inside_class(func: Callable[..., Any]) -> bool:
# denoting which class it belongs to. So, e.g. for A.foo the qualname
# would be A.foo while a global foo() would just be foo.
#
# Unfortuately, for nested functions this breaks. So inside an outer
# Unfortunately, for nested functions this breaks. So inside an outer
# function named outer, those two would end up having a qualname with
# outer.<locals>.A.foo and outer.<locals>.foo

25
discord/voice_client.py

@ -74,6 +74,7 @@ has_nacl: bool
try:
import nacl.secret # type: ignore
import nacl.utils # type: ignore
has_nacl = True
except ImportError:
@ -373,14 +374,14 @@ class VoiceClient(VoiceProtocol):
The endpoint we are connecting to.
channel: :class:`abc.Connectable`
The voice channel connected to.
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
channel: abc.Connectable
endpoint_ip: str
voice_port: int
secret_key: List[int]
ip: str
port: int
secret_key: Optional[str]
def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl:
@ -414,7 +415,7 @@ class VoiceClient(VoiceProtocol):
self.idrcs: Dict[int, int] = {}
self.ssids: Dict[int, int] = {}
warn_nacl = not has_nacl
warn_nacl: bool = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
@ -443,8 +444,15 @@ class VoiceClient(VoiceProtocol):
# Connection related
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
setattr(self, attr, 0)
else:
setattr(self, attr, val + value)
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id']
self.session_id: str = data['session_id']
channel_id = data['channel_id']
if not self._handshaking or self._potentially_reconnecting:
@ -484,11 +492,12 @@ class VoiceClient(VoiceProtocol):
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
self.endpoint = self.endpoint[6:] # Shouldn't ever be there...
# Just in case, strip it off since we're going to add it later
self.endpoint: str = self.endpoint[6:]
self.endpoint_ip = MISSING
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
if not self._handshaking:
@ -575,7 +584,7 @@ class VoiceClient(VoiceProtocol):
raise
if self._runner is MISSING:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early

337
discord/webhook/async_.py

@ -30,7 +30,7 @@ import json
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Sequence, Tuple, Union, TypeVar, Type, overload
from contextvars import ContextVar
import weakref
@ -43,7 +43,7 @@ from ..enums import try_enum, WebhookType
from ..user import BaseUser, User
from ..flags import MessageFlags
from ..asset import Asset
from ..http import Route, handle_message_parameters, MultipartParameters
from ..http import Route, handle_message_parameters, HTTPClient
from ..mixins import Hashable
from ..channel import PartialMessageable
from ..file import File
@ -58,23 +58,37 @@ __all__ = (
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..message import Attachment
from ..state import ConnectionState
from ..http import Response
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
import datetime
from ..types.webhook import (
Webhook as WebhookPayload,
SourceGuild as SourceGuildPayload,
)
from ..types.message import (
Message as MessagePayload,
)
from ..guild import Guild
from ..channel import TextChannel
from ..abc import Snowflake
import datetime
from ..types.user import (
User as UserPayload,
PartialUser as PartialUserPayload,
)
from ..types.channel import (
PartialChannel as PartialChannelPayload,
)
BE = TypeVar('BE', bound=BaseException)
_State = Union[ConnectionState, '_WebhookState']
MISSING = utils.MISSING
MISSING: Any = utils.MISSING
class AsyncDeferredLock:
@ -82,14 +96,19 @@ class AsyncDeferredLock:
self.lock = lock
self.delta: Optional[float] = None
async def __aenter__(self):
async def __aenter__(self) -> Self:
await self.lock.acquire()
return self
def delay_by(self, delta: float) -> None:
self.delta = delta
async def __aexit__(self, type, value, traceback):
async def __aexit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta:
await asyncio.sleep(self.delta)
self.lock.release()
@ -106,7 +125,7 @@ class AsyncWebhookAdapter:
*,
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
files: Optional[Sequence[File]] = None,
reason: Optional[str] = None,
auth_token: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
@ -259,7 +278,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None,
wait: bool = False,
) -> Response[Optional[MessagePayload]]:
@ -276,6 +295,7 @@ class AsyncWebhookAdapter:
message_id: int,
*,
session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[MessagePayload]:
route = Route(
'GET',
@ -284,7 +304,8 @@ class AsyncWebhookAdapter:
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def edit_webhook_message(
self,
@ -295,7 +316,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession,
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None,
) -> Response[Message]:
route = Route(
'PATCH',
@ -304,7 +326,8 @@ class AsyncWebhookAdapter:
webhook_token=token,
message_id=message_id,
)
return self.request(route, session, payload=payload, multipart=multipart, files=files)
params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def delete_webhook_message(
self,
@ -313,6 +336,7 @@ class AsyncWebhookAdapter:
message_id: int,
*,
session: aiohttp.ClientSession,
thread_id: Optional[int] = None,
) -> Response[None]:
route = Route(
'DELETE',
@ -321,7 +345,8 @@ class AsyncWebhookAdapter:
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def fetch_webhook(
self,
@ -344,111 +369,6 @@ class AsyncWebhookAdapter:
return self.request(route, session=session)
def interaction_response_params(type: int, data: Optional[Dict[str, Any]] = None) -> MultipartParameters:
payload: Dict[str, Any] = {
'type': type,
}
if data is not None:
payload['data'] = data
return MultipartParameters(payload=payload, multipart=None, files=None)
# This is a subset of handle_message_parameters
def interaction_message_response_params(
*,
type: int,
content: Optional[str] = MISSING,
tts: bool = False,
flags: MessageFlags = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = MISSING,
previous_allowed_mentions: Optional[AllowedMentions] = None,
) -> MultipartParameters:
if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.')
if embeds is not MISSING and embed is not MISSING:
raise TypeError('Cannot mix embed and embeds keyword arguments.')
if file is not MISSING:
files = [file]
if attachments is not MISSING and files is not MISSING:
raise TypeError('Cannot mix attachments and files keyword arguments.')
data: Optional[Dict[str, Any]] = {
'tts': tts,
}
if embeds is not MISSING:
if len(embeds) > 10:
raise ValueError('embeds has a maximum of 10 elements.')
data['embeds'] = [e.to_dict() for e in embeds]
if embed is not MISSING:
if embed is None:
data['embeds'] = []
else:
data['embeds'] = [embed.to_dict()]
if content is not MISSING:
if content is not None:
data['content'] = str(content)
else:
data['content'] = None
if flags is not MISSING:
data['flags'] = flags.value
if allowed_mentions:
if previous_allowed_mentions is not None:
data['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
else:
data['allowed_mentions'] = allowed_mentions.to_dict()
elif previous_allowed_mentions is not None:
data['allowed_mentions'] = previous_allowed_mentions.to_dict()
if attachments is MISSING:
attachments = files # type: ignore
else:
files = [a for a in attachments if isinstance(a, File)]
if attachments is not MISSING:
file_index = 0
attachments_payload = []
for attachment in attachments:
if isinstance(attachment, File):
attachments_payload.append(attachment.to_dict(file_index))
file_index += 1
else:
attachments_payload.append(attachment.to_dict())
data['attachments'] = attachments_payload
multipart = []
if files:
data = {'type': type, 'data': data}
multipart.append({'name': 'payload_json', 'value': utils._to_json(data)})
data = None
for index, file in enumerate(files):
multipart.append(
{
'name': f'files[{index}]',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
}
)
else:
data = {'type': type, 'data': data}
return MultipartParameters(payload=data, multipart=multipart, files=files)
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter())
@ -469,11 +389,11 @@ class PartialWebhookChannel(Hashable):
__slots__ = ('id', 'name')
def __init__(self, *, data):
self.id = int(data['id'])
self.name = data['name']
def __init__(self, *, data: PartialChannelPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
def __repr__(self):
def __repr__(self) -> str:
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>'
@ -494,13 +414,13 @@ class PartialWebhookGuild(Hashable):
__slots__ = ('id', 'name', '_icon', '_state')
def __init__(self, *, data, state):
self._state = state
self.id = int(data['id'])
self.name = data['name']
self._icon = data['icon']
def __init__(self, *, data: SourceGuildPayload, state: _State) -> None:
self._state: _State = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: str = data['icon']
def __repr__(self):
def __repr__(self) -> str:
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>'
@property
@ -514,14 +434,14 @@ class PartialWebhookGuild(Hashable):
class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
raise AttributeError('PartialWebhookState does not support http methods.')
class _WebhookState:
__slots__ = ('_parent', '_webhook')
__slots__ = ('_parent', '_webhook', '_thread')
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
def __init__(self, webhook: Any, parent: Optional[_State], thread: Snowflake = MISSING):
self._webhook: Any = webhook
self._parent: Optional[ConnectionState]
@ -530,23 +450,25 @@ class _WebhookState:
else:
self._parent = parent
def _get_guild(self, guild_id):
self._thread: Snowflake = thread
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
if self._parent is not None:
return self._parent._get_guild(guild_id)
return None
def store_user(self, data):
def store_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
if self._parent is not None:
return self._parent.store_user(data)
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
def create_user(self, data):
def create_user(self, data: Union[UserPayload, PartialUserPayload]) -> BaseUser:
# state parameter is artificial
return BaseUser(state=self, data=data) # type: ignore
@property
def http(self):
def http(self) -> Union[HTTPClient, _FriendlyHttpAttributeErrorHelper]:
if self._parent is not None:
return self._parent.http
@ -554,7 +476,7 @@ class _WebhookState:
# However, using it should result in a late-binding error
return _FriendlyHttpAttributeErrorHelper()
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
if self._parent is not None:
return getattr(self._parent, attr)
@ -578,9 +500,9 @@ class WebhookMessage(Message):
async def edit(
self,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> WebhookMessage:
"""|coro|
@ -593,8 +515,8 @@ class WebhookMessage(Message):
The edit is no longer in-place, instead the newly edited message is returned.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
@ -642,6 +564,7 @@ class WebhookMessage(Message):
embed=embed,
attachments=attachments,
allowed_mentions=allowed_mentions,
thread=self._state._thread,
)
async def add_files(self, *files: File) -> WebhookMessage:
@ -722,13 +645,13 @@ class WebhookMessage(Message):
async def inner_call(delay: float = delay):
await asyncio.sleep(delay)
try:
await self._state._webhook.delete_message(self.id)
await self._state._webhook.delete_message(self.id, thread=self._state._thread)
except HTTPException:
pass
asyncio.create_task(inner_call())
else:
await self._state._webhook.delete_message(self.id)
await self._state._webhook.delete_message(self.id, thread=self._state._thread)
class BaseWebhook(Hashable):
@ -747,19 +670,24 @@ class BaseWebhook(Hashable):
'_state',
)
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
def __init__(
self,
data: WebhookPayload,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
self.auth_token: Optional[str] = token
self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state)
self._state: _State = state or _WebhookState(self, parent=state)
self._update(data)
def _update(self, data: WebhookPayload):
self.id = int(data['id'])
self.type = try_enum(WebhookType, int(data['type']))
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.name = data.get('name')
self._avatar = data.get('avatar')
self.token = data.get('token')
def _update(self, data: WebhookPayload) -> None:
self.id: int = int(data['id'])
self.type: WebhookType = try_enum(WebhookType, int(data['type']))
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.name: Optional[str] = data.get('name')
self._avatar: Optional[str] = data.get('avatar')
self.token: Optional[str] = data.get('token')
user = data.get('user')
self.user: Optional[Union[BaseUser, User]] = None
@ -927,11 +855,17 @@ class Webhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: aiohttp.ClientSession,
token: Optional[str] = None,
state: Optional[_State] = None,
) -> None:
super().__init__(data, token, state)
self.session = session
self.session: aiohttp.ClientSession = session
def __repr__(self):
def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>'
@property
@ -940,7 +874,7 @@ class Webhook(BaseWebhook):
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook`.
Parameters
@ -976,12 +910,12 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook:
def from_url(cls, url: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Self:
"""Creates a partial :class:`Webhook` from a webhook URL.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
@ -1019,7 +953,7 @@ class Webhook(BaseWebhook):
return cls(data, session, token=bot_token) # type: ignore
@classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook:
def _as_follower(cls, data, *, channel, user) -> Self:
name = f"{channel.guild} #{channel}"
feed: WebhookPayload = {
'id': data['webhook_id'],
@ -1035,8 +969,8 @@ class Webhook(BaseWebhook):
return cls(feed, session=session, state=state, token=state.http.token)
@classmethod
def from_state(cls, data, state) -> Webhook:
session = state.http._HTTPClient__session
def from_state(cls, data: WebhookPayload, state: ConnectionState) -> Self:
session = state.http._HTTPClient__session # type: ignore
return cls(data, session=session, state=state, token=state.http.token)
async def fetch(self, *, prefer_auth: bool = True) -> Webhook:
@ -1085,7 +1019,7 @@ class Webhook(BaseWebhook):
return Webhook(data, self.session, token=self.auth_token, state=self._state)
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
"""|coro|
Deletes this Webhook.
@ -1137,8 +1071,8 @@ class Webhook(BaseWebhook):
Edits this Webhook.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``~InvalidArgument``.
Parameters
------------
@ -1203,8 +1137,8 @@ class Webhook(BaseWebhook):
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
def _create_message(self, data, *, thread: Snowflake):
state = _WebhookState(self, parent=self._state, thread=thread)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
# state is artificial
@ -1219,9 +1153,9 @@ class Webhook(BaseWebhook):
avatar_url: Any = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: Literal[True],
@ -1238,9 +1172,9 @@ class Webhook(BaseWebhook):
avatar_url: Any = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: Literal[False] = ...,
@ -1256,9 +1190,9 @@ class Webhook(BaseWebhook):
avatar_url: Any = MISSING,
tts: bool = False,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: bool = False,
@ -1278,8 +1212,8 @@ class Webhook(BaseWebhook):
``embeds`` parameter, which must be a :class:`list` of :class:`Embed` objects to send.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
@ -1389,11 +1323,11 @@ class Webhook(BaseWebhook):
msg = None
if wait:
msg = self._create_message(data)
msg = self._create_message(data, thread=thread)
return msg
async def fetch_message(self, id: int, /) -> WebhookMessage:
async def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> WebhookMessage:
"""|coro|
Retrieves a single :class:`~discord.WebhookMessage` owned by this webhook.
@ -1404,6 +1338,8 @@ class Webhook(BaseWebhook):
------------
id: :class:`int`
The message ID to look for.
thread: :class:`~discord.abc.Snowflake`
The thread to look in.
Raises
--------
@ -1425,24 +1361,30 @@ class Webhook(BaseWebhook):
if self.token is None:
raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter = async_context.get()
data = await adapter.get_webhook_message(
self.id,
self.token,
id,
session=self.session,
thread_id=thread_id,
)
return self._create_message(data)
return self._create_message(data, thread=thread)
async def edit_message(
self,
message_id: int,
*,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
thread: Snowflake = MISSING,
) -> WebhookMessage:
"""|coro|
@ -1457,8 +1399,8 @@ class Webhook(BaseWebhook):
The edit is no longer in-place, instead the newly edited message is returned.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
@ -1479,6 +1421,10 @@ class Webhook(BaseWebhook):
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
thread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises
-------
@ -1511,6 +1457,11 @@ class Webhook(BaseWebhook):
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter = async_context.get()
data = await adapter.edit_webhook_message(
self.id,
@ -1520,12 +1471,13 @@ class Webhook(BaseWebhook):
payload=params.payload,
multipart=params.multipart,
files=params.files,
thread_id=thread_id,
)
message = self._create_message(data)
message = self._create_message(data, thread=thread)
return message
async def delete_message(self, message_id: int, /) -> None:
async def delete_message(self, message_id: int, /, *, thread: Snowflake = MISSING) -> None:
"""|coro|
Deletes a message owned by this webhook.
@ -1540,13 +1492,17 @@ class Webhook(BaseWebhook):
``message_id`` parameter is now positional-only.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError`.
This function will now raise :exc:`ValueError` instead of
``InvalidArgument``.
Parameters
------------
message_id: :class:`int`
The message ID to delete.
thread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises
-------
@ -1560,10 +1516,15 @@ class Webhook(BaseWebhook):
if self.token is None:
raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter = async_context.get()
await adapter.delete_webhook_message(
self.id,
self.token,
message_id,
session=self.session,
thread_id=thread_id,
)

148
discord/webhook/sync.py

@ -37,7 +37,7 @@ import time
import re
from urllib.parse import quote as urlquote
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Sequence, Tuple, Union, TypeVar, Type, overload
import weakref
from .. import utils
@ -56,36 +56,50 @@ __all__ = (
_log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing_extensions import Self
from types import TracebackType
from ..file import File
from ..embeds import Embed
from ..mentions import AllowedMentions
from ..message import Attachment
from ..abc import Snowflake
from ..state import ConnectionState
from ..types.webhook import (
Webhook as WebhookPayload,
)
from ..abc import Snowflake
from ..types.message import (
Message as MessagePayload,
)
BE = TypeVar('BE', bound=BaseException)
try:
from requests import Session, Response
except ModuleNotFoundError:
pass
MISSING = utils.MISSING
MISSING: Any = utils.MISSING
class DeferredLock:
def __init__(self, lock: threading.Lock):
self.lock = lock
def __init__(self, lock: threading.Lock) -> None:
self.lock: threading.Lock = lock
self.delta: Optional[float] = None
def __enter__(self):
def __enter__(self) -> Self:
self.lock.acquire()
return self
def delay_by(self, delta: float) -> None:
self.delta = delta
def __exit__(self, type, value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BE]],
exc: Optional[BE],
traceback: Optional[TracebackType],
) -> None:
if self.delta:
time.sleep(self.delta)
self.lock.release()
@ -102,7 +116,7 @@ class WebhookAdapter:
*,
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
files: Optional[Sequence[File]] = None,
reason: Optional[str] = None,
auth_token: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
@ -218,7 +232,7 @@ class WebhookAdapter:
token: Optional[str] = None,
session: Session,
reason: Optional[str] = None,
):
) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
@ -229,7 +243,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> None:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
@ -241,7 +255,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
@ -253,7 +267,7 @@ class WebhookAdapter:
*,
session: Session,
reason: Optional[str] = None,
):
) -> WebhookPayload:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
@ -265,10 +279,10 @@ class WebhookAdapter:
session: Session,
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None,
wait: bool = False,
):
) -> MessagePayload:
params = {'wait': int(wait)}
if thread_id:
params['thread_id'] = thread_id
@ -282,7 +296,8 @@ class WebhookAdapter:
message_id: int,
*,
session: Session,
):
thread_id: Optional[int] = None,
) -> MessagePayload:
route = Route(
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -290,7 +305,8 @@ class WebhookAdapter:
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def edit_webhook_message(
self,
@ -301,8 +317,9 @@ class WebhookAdapter:
session: Session,
payload: Optional[Dict[str, Any]] = None,
multipart: Optional[List[Dict[str, Any]]] = None,
files: Optional[List[File]] = None,
):
files: Optional[Sequence[File]] = None,
thread_id: Optional[int] = None,
) -> MessagePayload:
route = Route(
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -310,7 +327,8 @@ class WebhookAdapter:
webhook_token=token,
message_id=message_id,
)
return self.request(route, session, payload=payload, multipart=multipart, files=files)
params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def delete_webhook_message(
self,
@ -319,7 +337,8 @@ class WebhookAdapter:
message_id: int,
*,
session: Session,
):
thread_id: Optional[int] = None,
) -> None:
route = Route(
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
@ -327,7 +346,8 @@ class WebhookAdapter:
webhook_token=token,
message_id=message_id,
)
return self.request(route, session)
params = None if thread_id is None else {'thread_id': thread_id}
return self.request(route, session, params=params)
def fetch_webhook(
self,
@ -335,7 +355,7 @@ class WebhookAdapter:
token: str,
*,
session: Session,
):
) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
@ -345,7 +365,7 @@ class WebhookAdapter:
token: str,
*,
session: Session,
):
) -> WebhookPayload:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
@ -380,16 +400,16 @@ class SyncWebhookMessage(Message):
def edit(
self,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> SyncWebhookMessage:
"""Edits the message.
.. versionchanged:: 2.0
This function no-longer raises ``InvalidArgument`` instead raising
:exc:`ValueError` or :exc:`TypeError` in various cases.
This function will now raise :exc:`TypeError` or
:exc:`ValueError` instead of ``InvalidArgument``.
Parameters
------------
@ -437,6 +457,7 @@ class SyncWebhookMessage(Message):
embed=embed,
attachments=attachments,
allowed_mentions=allowed_mentions,
thread=self._state._thread,
)
def add_files(self, *files: File) -> SyncWebhookMessage:
@ -508,7 +529,7 @@ class SyncWebhookMessage(Message):
if delay is not None:
time.sleep(delay)
self._state._webhook.delete_message(self.id)
self._state._webhook.delete_message(self.id, thread=self._state._thread)
class SyncWebhook(BaseWebhook):
@ -569,11 +590,17 @@ class SyncWebhook(BaseWebhook):
__slots__: Tuple[str, ...] = ('session',)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
def __init__(
self,
data: WebhookPayload,
session: Session,
token: Optional[str] = None,
state: Optional[Union[ConnectionState, _WebhookState]] = None,
) -> None:
super().__init__(data, token, state)
self.session = session
self.session: Session = session
def __repr__(self):
def __repr__(self) -> str:
return f'<Webhook id={self.id!r}>'
@property
@ -812,8 +839,8 @@ class SyncWebhook(BaseWebhook):
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
def _create_message(self, data: MessagePayload, *, thread: Snowflake = MISSING) -> SyncWebhookMessage:
state = _WebhookState(self, parent=self._state, thread=thread)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore
# state is artificial
@ -828,10 +855,11 @@ class SyncWebhook(BaseWebhook):
avatar_url: Any = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: Literal[True],
suppress_embeds: bool = MISSING,
) -> SyncWebhookMessage:
@ -846,10 +874,11 @@ class SyncWebhook(BaseWebhook):
avatar_url: Any = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: Literal[False] = ...,
suppress_embeds: bool = MISSING,
) -> None:
@ -863,9 +892,9 @@ class SyncWebhook(BaseWebhook):
avatar_url: Any = MISSING,
tts: bool = False,
file: File = MISSING,
files: List[File] = MISSING,
files: Sequence[File] = MISSING,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
allowed_mentions: AllowedMentions = MISSING,
thread: Snowflake = MISSING,
wait: bool = False,
@ -984,9 +1013,9 @@ class SyncWebhook(BaseWebhook):
wait=wait,
)
if wait:
return self._create_message(data)
return self._create_message(data, thread=thread)
def fetch_message(self, id: int, /) -> SyncWebhookMessage:
def fetch_message(self, id: int, /, *, thread: Snowflake = MISSING) -> SyncWebhookMessage:
"""Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook.
.. versionadded:: 2.0
@ -995,6 +1024,8 @@ class SyncWebhook(BaseWebhook):
------------
id: :class:`int`
The message ID to look for.
thread: :class:`~discord.abc.Snowflake`
The thread to look in.
Raises
--------
@ -1016,24 +1047,30 @@ class SyncWebhook(BaseWebhook):
if self.token is None:
raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.get_webhook_message(
self.id,
self.token,
id,
session=self.session,
thread_id=thread_id,
)
return self._create_message(data)
return self._create_message(data, thread=thread)
def edit_message(
self,
message_id: int,
*,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embeds: Sequence[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
attachments: List[Union[Attachment, File]] = MISSING,
attachments: Sequence[Union[Attachment, File]] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
thread: Snowflake = MISSING,
) -> SyncWebhookMessage:
"""Edits a message owned by this webhook.
@ -1061,6 +1098,10 @@ class SyncWebhook(BaseWebhook):
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
thread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises
-------
@ -1087,6 +1128,11 @@ class SyncWebhook(BaseWebhook):
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.edit_webhook_message(
self.id,
@ -1096,10 +1142,11 @@ class SyncWebhook(BaseWebhook):
payload=params.payload,
multipart=params.multipart,
files=params.files,
thread_id=thread_id,
)
return self._create_message(data)
return self._create_message(data, thread=thread)
def delete_message(self, message_id: int, /) -> None:
def delete_message(self, message_id: int, /, *, thread: Snowflake = MISSING) -> None:
"""Deletes a message owned by this webhook.
This is a lower level interface to :meth:`WebhookMessage.delete` in case
@ -1111,6 +1158,10 @@ class SyncWebhook(BaseWebhook):
------------
message_id: :class:`int`
The message ID to delete.
hread: :class:`~discord.abc.Snowflake`
The thread the webhook message belongs to.
.. versionadded:: 2.0
Raises
-------
@ -1124,10 +1175,15 @@ class SyncWebhook(BaseWebhook):
if self.token is None:
raise ValueError('This webhook does not have a token associated with it')
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
adapter: WebhookAdapter = _get_webhook_adapter()
adapter.delete_webhook_message(
self.id,
self.token,
message_id,
session=self.session,
thread_id=thread_id,
)

31
discord/widget.py

@ -188,7 +188,7 @@ class WidgetMember(BaseUser):
except KeyError:
activity = None
else:
activity = create_activity(game)
activity = create_activity(game, state)
self.activity: Optional[Union[BaseActivity, Spotify]] = activity
@ -231,7 +231,7 @@ class Widget:
channels: List[:class:`WidgetChannel`]
The accessible voice channels in the guild.
members: List[:class:`Member`]
The online members in the server. Offline members
The online members in the guild. Offline members
do not appear in the widget.
.. note::
@ -240,10 +240,15 @@ class Widget:
the users will be "anonymized" with linear IDs and discriminator
information being incorrect. Likewise, the number of members
retrieved is capped.
presence_count: :class:`int`
The approximate number of online members in the guild.
Offline members are not included in this count.
.. versionadded:: 2.0
"""
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name', 'presence_count')
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
self._state = state
@ -268,10 +273,12 @@ class Widget:
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel))
self.presence_count: int = data['presence_count']
def __str__(self) -> str:
return self.json_url
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
if isinstance(other, Widget):
return self.id == other.id
return False
@ -290,11 +297,11 @@ class Widget:
return f"https://discord.com/api/guilds/{self.id}/widget.json"
@property
def invite_url(self) -> str:
def invite_url(self) -> Optional[str]:
"""Optional[:class:`str`]: The invite URL for the guild, if available."""
return self._invite
async def fetch_invite(self, *, with_counts: bool = True) -> Invite:
async def fetch_invite(self, *, with_counts: bool = True) -> Optional[Invite]:
"""|coro|
Retrieves an :class:`Invite` from the widget's invite URL.
@ -310,9 +317,11 @@ class Widget:
Returns
--------
:class:`Invite`
The invite from the widget's invite URL.
Optional[:class:`Invite`]
The invite from the widget's invite URL, if available.
"""
resolved = resolve_invite(self._invite)
data = await self._state.http.get_invite(resolved.code, with_counts=with_counts)
return Invite.from_incomplete(state=self._state, data=data)
if self._invite:
resolved = resolve_invite(self._invite)
data = await self._state.http.get_invite(resolved.code, with_counts=with_counts)
return Invite.from_incomplete(state=self._state, data=data)
return None

16
docs/_static/custom.js

@ -95,3 +95,19 @@ document.addEventListener('keydown', (event) => {
activeModal.close();
}
});
function searchBarClick(event, which) {
event.preventDefault();
if (event.button === 1 || event.buttons === 4) {
which.target = "_blank"; // Middle mouse button was clicked. Set our target to a new tab.
}
else if (event.button === 2) {
return // Right button was clicked... Don't do anything here.
}
else {
which.target = "_self"; // Revert to same window.
}
which.submit();
}

18
docs/_static/style.css

@ -61,6 +61,8 @@ Historically however, thanks to:
--search-focus: var(--blue-1);
--search-button: var(--grey-1);
--search-button-hover: var(--grey-1-8);
--search-sidebar-background: var(--grey-1);
--search-sidebar-text: var(--grey-7);
--footer-text: var(--grey-5);
--footer-link: var(--grey-6);
--hr-border: var(--grey-2);
@ -167,6 +169,8 @@ Historically however, thanks to:
--attribute-table-entry-hover-text: var(--blue-1);
--attribute-table-badge: var(--grey-4);
--highlighted-text: rgba(250, 166, 26, 0.2);
--search-sidebar-background: var(--grey-7);
--search-sidebar-text: var(--search-text);
}
img[src$="snake_dark.svg"] {
@ -523,6 +527,20 @@ input[type=search]:focus ~ button[type=submit] {
color: var(--search-button-hover);
}
/* search sidebar */
.search-sidebar > input[type=search],
.search-sidebar > button[type=submit] {
background-color: var(--search-sidebar-background);
color: var(--search-sidebar-text);
}
.sidebar-toggle .search-sidebar > input[type=search],
.sidebar-toggle .search-sidebar > button[type=submit] {
background-color: var(--mobile-nav-background);
color: var(--mobile-nav-text);
}
/* main content area */
main {

12
docs/_templates/layout.html

@ -89,10 +89,10 @@
<option value="{{ pathto(p + '/index')|e }}" {% if pagename is prefixedwith p %}selected{% endif %}>{{ ext }}</option>
{%- endfor %}
</select>
<form role="search" class="search" action="{{ pathto('search') }}" method="get">
<form id="search-form" role="search" class="search" action="{{ pathto('search') }}" method="get">
<div class="search-wrapper">
<input type="search" name="q" placeholder="{{ _('Search documentation') }}" />
<button type="submit">
<button type="submit" onmousedown="searchBarClick(event, document.getElementById('search-form'));">
<span class="material-icons">search</span>
</button>
</div>
@ -110,6 +110,14 @@
<span class="material-icons">settings</span>
</span>
<div id="sidebar">
<form id="sidebar-form" role="search" class="search" action="{{ pathto('search') }}" method="get">
<div class="search-wrapper search-sidebar">
<input type="search" name="q" placeholder="{{ _('Search documentation') }}" />
<button type="submit" onmousedown="searchBarClick(event, document.getElementById('sidebar-form'));">
<span class="material-icons">search</span>
</button>
</div>
</form>
{%- include "localtoc.html" %}
</div>
</aside>

1008
docs/api.rst

File diff suppressed because it is too large

2
docs/conf.py

@ -145,6 +145,8 @@ pygments_style = 'friendly'
# Nitpicky mode options
nitpick_ignore_files = [
"migrating_to_async",
"migrating_to_v1",
"migrating",
"whats_new",
]

21
docs/crowdin.yml

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
project_id: "362783"
api_token_env: CROWDIN_API_KEY
files:
- source: /_build/locale/**/*.pot
translation: /locale/%two_letters_code%/LC_MESSAGES/%original_path%/%file_name%.po
# You must use `crowdin download --all` for this project
# I discovered after like an hour of debugging the Java CLI that `--all` actually means "use server sources"
# Without this, crowdin tries to determine the mapping itself, and decides that because
# `/locale/ja/LC_MESSAGES/_build/locale/...` doesn't exist, that it won't download anything
# There is no workaround for this. I tried. Trying to adjust the project base path just breaks things further.
# Crowdin does the conflict resolution on its end. The process to update translations is thus:
# - make gettext
# - crowdin upload
# - crowdin download --all
# You must set ${CROWDIN_API_KEY} in the environment.
# I will write an Actions workflow for this at a later date.

7
docs/ext/commands/api.rst

@ -411,6 +411,9 @@ Converters
.. autoclass:: discord.ext.commands.GuildStickerConverter
:members:
.. autoclass:: discord.ext.commands.ScheduledEventConverter
:members:
.. autoclass:: discord.ext.commands.clean_content
:members:
@ -539,6 +542,9 @@ Exceptions
.. autoexception:: discord.ext.commands.GuildStickerNotFound
:members:
.. autoexception:: discord.ext.commands.ScheduledEventNotFound
:members:
.. autoexception:: discord.ext.commands.BadBoolArgument
:members:
@ -623,6 +629,7 @@ Exception Hierarchy
- :exc:`~.commands.BadInviteArgument`
- :exc:`~.commands.EmojiNotFound`
- :exc:`~.commands.GuildStickerNotFound`
- :exc:`~.commands.ScheduledEventNotFound`
- :exc:`~.commands.PartialEmojiConversionFailure`
- :exc:`~.commands.BadBoolArgument`
- :exc:`~.commands.ThreadNotFound`

5
docs/ext/commands/cogs.rst

@ -58,7 +58,7 @@ Once you have defined your cogs, you need to tell the bot to register the cogs t
.. code-block:: python3
bot.add_cog(Greetings(bot))
await bot.add_cog(Greetings(bot))
This binds the cog to the bot, adding all commands and listeners to the bot automatically.
@ -66,7 +66,7 @@ Note that we reference the cog by name, which we can override through :ref:`ext_
.. code-block:: python3
bot.remove_cog('Greetings')
await bot.remove_cog('Greetings')
Using Cogs
-------------
@ -112,6 +112,7 @@ As cogs get more complicated and have more commands, there comes a point where w
They are as follows:
- :meth:`.Cog.cog_load`
- :meth:`.Cog.cog_unload`
- :meth:`.Cog.cog_check`
- :meth:`.Cog.cog_command_error`

19
docs/ext/commands/commands.rst

@ -11,6 +11,13 @@ how you can arbitrarily nest groups and commands to have a rich sub-command syst
Commands are defined by attaching it to a regular Python function. The command is then invoked by the user using a similar
signature to the Python function.
.. warning::
You must have access to the :attr:`~discord.Intents.message_content` intent for the commands extension
to function. This must be set both in the developer portal and within your code.
Failure to do this will result in your bot not responding to any of your commands.
For example, in the given command definition:
.. code-block:: python3
@ -171,9 +178,9 @@ As seen earlier, every command must take at least a single parameter, called the
This parameter gives you access to something called the "invocation context". Essentially all the information you need to
know how the command was executed. It contains a lot of useful information:
- :attr:`.Context.guild` to fetch the :class:`Guild` of the command, if any.
- :attr:`.Context.message` to fetch the :class:`Message` of the command.
- :attr:`.Context.author` to fetch the :class:`Member` or :class:`User` that called the command.
- :attr:`.Context.guild` returns the :class:`Guild` of the command, if any.
- :attr:`.Context.message` returns the :class:`Message` of the command.
- :attr:`.Context.author` returns the :class:`Member` or :class:`User` that called the command.
- :meth:`.Context.send` to send a message to the channel the command was used in.
The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you
@ -393,6 +400,8 @@ A lot of discord models work out of the gate as a parameter:
- :class:`Emoji`
- :class:`PartialEmoji`
- :class:`Thread` (since v2.0)
- :class:`GuildSticker` (since v2.0)
- :class:`ScheduledEvent` (since v2.0)
Having any of these set as the converter will intelligently convert the argument to the appropriate target type you
specify.
@ -441,6 +450,10 @@ converter is given below:
+--------------------------+-------------------------------------------------+
| :class:`Thread` | :class:`~ext.commands.ThreadConverter` |
+--------------------------+-------------------------------------------------+
| :class:`GuildSticker` | :class:`~ext.commands.GuildStickerConverter` |
+--------------------------+-------------------------------------------------+
| :class:`ScheduledEvent` | :class:`~ext.commands.ScheduledEventConverter` |
+--------------------------+-------------------------------------------------+
By providing the converter it allows us to use them as building blocks for another converter:

10
docs/ext/commands/extensions.rst

@ -24,10 +24,10 @@ An example extension looks like this:
async def hello(ctx):
await ctx.send(f'Hello {ctx.author.display_name}.')
def setup(bot):
async def setup(bot):
bot.add_command(hello)
In this example we define a simple command, and when the extension is loaded this command is added to the bot. Now the final step to this is loading the extension, which we do by calling :meth:`.Bot.load_extension`. To load this extension we call ``bot.load_extension('hello')``.
In this example we define a simple command, and when the extension is loaded this command is added to the bot. Now the final step to this is loading the extension, which we do by calling :meth:`.Bot.load_extension`. To load this extension we call ``await bot.load_extension('hello')``.
.. admonition:: Cogs
:class: helpful
@ -45,7 +45,7 @@ When you make a change to the extension and want to reload the references, the l
.. code-block:: python3
>>> bot.reload_extension('hello')
>>> await bot.reload_extension('hello')
Once the extension reloads, any changes that we did will be applied. This is useful if we want to add or remove functionality without restarting our bot. If an error occurred during the reloading process, the bot will pretend as if the reload never happened.
@ -57,8 +57,8 @@ Although rare, sometimes an extension needs to clean-up or know when it's being
.. code-block:: python3
:caption: basic_ext.py
def setup(bot):
async def setup(bot):
print('I am being loaded!')
def teardown(bot):
async def teardown(bot):
print('I am being unloaded!')

4
docs/faq.rst

@ -326,10 +326,6 @@ Quick example: ::
embed.set_image(url="attachment://image.png")
await channel.send(file=file, embed=embed)
.. note ::
Due to a Discord limitation, filenames may not include underscores.
Is there an event for audit log entries being created?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1397
docs/migrating.rst

File diff suppressed because it is too large

1174
docs/migrating_to_v1.rst

File diff suppressed because it is too large

7
docs/quickstart.rst

@ -19,9 +19,14 @@ It looks something like this:
.. code-block:: python3
# This example requires the 'message_content' intent.
import discord
client = discord.Client()
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)
@client.event
async def on_ready():

1
examples/background_task.py

@ -9,6 +9,7 @@ class MyClient(discord.Client):
# an attribute we can access from our task
self.counter = 0
async def setup_hook(self) -> None:
# start the task to run in the background
self.my_background_task.start()

1
examples/background_task_asyncio.py

@ -5,6 +5,7 @@ class MyClient(discord.Client):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def setup_hook(self) -> None:
# create the background task and run it in the background
self.bg_task = self.loop.create_task(self.my_background_task())

8
examples/basic_voice.py

@ -131,5 +131,9 @@ async def on_ready():
print(f'Logged in as {bot.user} (ID: {bot.user.id})')
print('------')
bot.add_cog(Music(bot))
bot.run('token')
async def main():
async with bot:
await bot.add_cog(Music(bot))
await bot.start('token')
asyncio.run(main())

69
examples/modal.py

@ -0,0 +1,69 @@
import discord
from discord import app_commands
import traceback
# Just default intents and a `discord.Client` instance
# We don't need a `commands.Bot` instance because we are not
# creating text-based commands.
intents = discord.Intents.default()
client = discord.Client(intents=intents)
# We need an `discord.app_commands.CommandTree` instance
# to register application commands (slash commands in this case)
tree = app_commands.CommandTree(client)
# The guild in which this slash command will be registered.
# As global commands can take up to an hour to propagate, it is ideal
# to test it in a guild.
TEST_GUILD = discord.Object(ID)
@client.event
async def on_ready():
print(f'Logged in as {client.user} (ID: {client.user.id})')
print('------')
# Sync the application command with Discord.
await tree.sync(guild=TEST_GUILD)
class Feedback(discord.ui.Modal, title='Feedback'):
# Our modal classes MUST subclass `discord.ui.Modal`,
# but the title can be whatever you want.
# This will be a short input, where the user can enter their name
# It will also have a placeholder, as denoted by the `placeholder` kwarg.
# By default, it is required and is a short-style input which is exactly
# what we want.
name = discord.ui.TextInput(
label='Name',
placeholder='Your name here...',
)
# This is a longer, paragraph style input, where user can submit feedback
# Unlike the name, it is not required. If filled out, however, it will
# only accept a maximum of 300 characters, as denoted by the
# `max_length=300` kwarg.
feedback = discord.ui.TextInput(
label='What do you think of this new feature?',
style=discord.TextStyle.long,
placeholder='Type your feedback here...',
required=False,
max_length=300,
)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f'Thanks for your feedback, {self.name.value}!', ephemeral=True)
async def on_error(self, error: Exception, interaction: discord.Interaction) -> None:
await interaction.response.send_message('Oops! Something went wrong.', ephemeral=True)
# Make sure we know what the error actually is
traceback.print_tb(error.__traceback__)
@tree.command(guild=TEST_GUILD, description="Submit feedback")
async def feedback(interaction: discord.Interaction):
# Send the modal with an instance of our `Feedback` class
await interaction.response.send_modal(Feedback())
client.run('token')

2
pyproject.toml

@ -28,6 +28,7 @@ line_length = 125
[tool.pyright]
include = [
"discord",
"discord/app_commands",
"discord/types",
"discord/ext",
"discord/ext/commands",
@ -39,6 +40,7 @@ exclude = [
"dist",
"docs",
]
reportUnnecessaryTypeIgnoreComment = "warning"
pythonVersion = "3.8"
typeCheckingMode = "basic"

2
requirements.txt

@ -1 +1 @@
aiohttp>=3.6.0,<3.9.0
aiohttp>=3.6.0,<4

2
setup.py

@ -34,7 +34,7 @@ with open('README.rst') as f:
readme = f.read()
extras_require = {
'voice': ['PyNaCl>=1.3.0,<1.5'],
'voice': ['PyNaCl>=1.3.0,<1.6'],
'docs': [
'sphinx==4.4.0',
'sphinxcontrib_trio==1.1.2',

103
tests/test_ext_tasks.py

@ -10,6 +10,7 @@ import asyncio
import datetime
import pytest
import sys
from discord import utils
from discord.ext import tasks
@ -75,3 +76,105 @@ async def test_explicit_initial_runs_tomorrow_multi():
assert not has_run
finally:
loop.cancel()
def test_task_regression_issue7659():
jst = datetime.timezone(datetime.timedelta(hours=9))
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
@tasks.loop(time=times)
async def loop():
pass
before_midnight = datetime.datetime(2022, 3, 12, 23, 50, 59, tzinfo=jst)
after_midnight = before_midnight + datetime.timedelta(minutes=9, seconds=2)
expected_before_midnight = datetime.datetime(2022, 3, 13, 0, 0, 0, tzinfo=jst)
expected_after_midnight = datetime.datetime(2022, 3, 13, 3, 0, 0, tzinfo=jst)
assert loop._get_next_sleep_time(before_midnight) == expected_before_midnight
assert loop._get_next_sleep_time(after_midnight) == expected_after_midnight
today = datetime.date.today()
minute_before = [datetime.datetime.combine(today, time, tzinfo=jst) - datetime.timedelta(minutes=1) for time in times]
for before, expected_time in zip(minute_before, times):
expected = datetime.datetime.combine(today, expected_time, tzinfo=jst)
actual = loop._get_next_sleep_time(before)
assert actual == expected
def test_task_regression_issue7676():
jst = datetime.timezone(datetime.timedelta(hours=9))
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
@tasks.loop(time=times)
async def loop():
pass
# Create pseudo UTC times
now = utils.utcnow()
today = now.date()
times_before_in_utc = [
datetime.datetime.combine(today, time, tzinfo=jst).astimezone(datetime.timezone.utc) - datetime.timedelta(minutes=1)
for time in times
]
for before, expected_time in zip(times_before_in_utc, times):
actual = loop._get_next_sleep_time(before)
actual_time = actual.timetz()
assert actual_time == expected_time
@pytest.mark.skipif(sys.version_info < (3, 9), reason="zoneinfo requires 3.9")
def test_task_is_imaginary():
import zoneinfo
tz = zoneinfo.ZoneInfo('America/New_York')
# 2:30 AM was skipped
dt = datetime.datetime(2022, 3, 13, 2, 30, tzinfo=tz)
assert tasks.is_imaginary(dt)
now = utils.utcnow()
# UTC time is never imaginary or ambiguous
assert not tasks.is_imaginary(now)
@pytest.mark.skipif(sys.version_info < (3, 9), reason="zoneinfo requires 3.9")
def test_task_is_ambiguous():
import zoneinfo
tz = zoneinfo.ZoneInfo('America/New_York')
# 1:30 AM happened twice
dt = datetime.datetime(2022, 11, 6, 1, 30, tzinfo=tz)
assert tasks.is_ambiguous(dt)
now = utils.utcnow()
# UTC time is never imaginary or ambiguous
assert not tasks.is_imaginary(now)
@pytest.mark.skipif(sys.version_info < (3, 9), reason="zoneinfo requires 3.9")
@pytest.mark.parametrize(
('dt', 'key', 'expected'),
[
(datetime.datetime(2022, 11, 6, 1, 30), 'America/New_York', datetime.datetime(2022, 11, 6, 1, 30, fold=1)),
(datetime.datetime(2022, 3, 13, 2, 30), 'America/New_York', datetime.datetime(2022, 3, 13, 3, 30)),
(datetime.datetime(2022, 4, 8, 2, 30), 'America/New_York', datetime.datetime(2022, 4, 8, 2, 30)),
(datetime.datetime(2023, 1, 7, 12, 30), 'UTC', datetime.datetime(2023, 1, 7, 12, 30)),
],
)
def test_task_date_resolve(dt, key, expected):
import zoneinfo
tz = zoneinfo.ZoneInfo(key)
actual = tasks.resolve_datetime(dt.replace(tzinfo=tz))
expected = expected.replace(tzinfo=tz)
assert actual == expected

Loading…
Cancel
Save