diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 4761544c9..3c75626a6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -41,11 +41,9 @@ jobs: - name: Run pyright run: | - # It is OK for the types to not pass at this stage - # We are just running it as a quick reference check - pyright || echo "Type checking did not pass" + pyright - name: Run black if: ${{ always() && steps.install-deps.outcome == 'success' }} run: | - black --check --verbose discord + black --check --verbose discord examples diff --git a/discord/abc.py b/discord/abc.py index 062d5027f..da6064096 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -79,9 +79,8 @@ if TYPE_CHECKING: from .state import ConnectionState from .guild import Guild from .member import Member - from .channel import CategoryChannel from .message import Message, MessageReference, PartialMessage - from .channel import DMChannel, GroupChannel, PartialMessageable, TextChannel, VocalGuildChannel + from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable, VoiceChannel, CategoryChannel from .threads import Thread from .enums import InviteTarget from .types.channel import ( @@ -94,10 +93,9 @@ if TYPE_CHECKING: SnowflakeList, ) - PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable] + PartialMessageableChannel = Union[TextChannel, VoiceChannel, Thread, DMChannel, PartialMessageable] MessageableChannel = Union[PartialMessageableChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] - ConnectableChannel = Union[VocalGuildChannel, DMChannel, GroupChannel, User] MISSING = utils.MISSING @@ -110,6 +108,43 @@ class _Undefined: _undefined: Any = _Undefined() +async def _purge_helper( + channel: Union[Thread, TextChannel, VoiceChannel], + *, + limit: Optional[int] = 100, + check: Callable[[Message], bool] = MISSING, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = False, + reason: Optional[str] = None, +) -> List[Message]: + if check is MISSING: + check = lambda m: True + + state = channel._state + channel_id = channel.id + iterator = channel.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) + ret: List[Message] = [] + count = 0 + + async for message in iterator: + if count == 50: + to_delete = ret[-50:] + await state._delete_messages(channel_id, to_delete, reason=reason) + count = 0 + if not check(message): + continue + + count += 1 + ret.append(message) + + # Some messages remaining to poll + to_delete = ret[-count:] + await state._delete_messages(channel_id, to_delete, reason=reason) + return ret + + @runtime_checkable class Snowflake(Protocol): """An ABC that details the common operations on a Discord model. @@ -528,7 +563,7 @@ class GuildChannel: If there is no category then this is ``None``. """ - return self.guild.get_channel(self.category_id) # type: ignore - These are coerced into CategoryChannel + return self.guild.get_channel(self.category_id) # type: ignore # These are coerced into CategoryChannel @property def permissions_synced(self) -> bool: @@ -555,6 +590,7 @@ class GuildChannel: - Guild roles - Channel overrides - Member overrides + - Member timeout If a :class:`~discord.Role` is passed, then it checks the permissions someone with that role would have, which is essentially: @@ -641,6 +677,12 @@ class GuildChannel: if base.administrator: return Permissions.all() + if obj.is_timed_out(): + # Timeout leads to every permission except VIEW_CHANNEL and READ_MESSAGE_HISTORY + # being explicitly denied + base.value &= Permissions._timeout_mask() + return base + # Apply @everyone allow/deny first since it's special try: maybe_everyone = self._overwrites[0] @@ -860,7 +902,7 @@ class GuildChannel: obj = cls(state=self._state, guild=self.guild, data=data) # Temporarily add it to the cache - self.guild._channels[obj.id] = obj # type: ignore - obj is a GuildChannel + self.guild._channels[obj.id] = obj # type: ignore # obj is a GuildChannel return obj async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> Self: @@ -1768,6 +1810,8 @@ class Connectable(Protocol): - :class:`~discord.StageChannel` - :class:`~discord.DMChannel` - :class:`~discord.GroupChannel` + - :class:`~discord.User` + - :class:`~discord.Member` """ __slots__ = () diff --git a/discord/appinfo.py b/discord/appinfo.py index a4ec407c5..26ec2cc66 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -316,7 +316,7 @@ class Application(PartialApplication): self.redirect_uris: List[str] = data.get('redirect_uris', []) self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id') self.slug: Optional[str] = data.get('slug') - self.interactions_endpoint_url: Optional[str] = data['interactions_endpoint_url'] + self.interactions_endpoint_url: Optional[str] = data.get('interactions_endpoint_url') self.verification_state = try_enum(ApplicationVerificationState, data['verification_state']) self.store_application_state = try_enum(StoreApplicationState, data['store_application_state']) @@ -335,7 +335,7 @@ class Application(PartialApplication): if owner is not None: self.owner: abcUser = state.create_user(owner) else: - self.owner: abcUser = state.user # type: ignore - state.user will always be present here + self.owner: abcUser = state.user # type: ignore # state.user will always be present here def __repr__(self) -> str: return ( @@ -469,7 +469,7 @@ class Application(PartialApplication): The new secret. """ data = await self._state.http.reset_secret(self.id) - return data['secret'] + return data['secret'] # type: ignore # Usually not there async def create_bot(self) -> ApplicationBot: """|coro| @@ -544,7 +544,7 @@ class InteractionApplication(Hashable): self._icon: Optional[str] = data.get('icon') self.type: Optional[ApplicationType] = try_enum(ApplicationType, data['type']) if 'type' in data else None - self.bot: User = None # type: ignore - This should never be None but it's volatile + self.bot: User = None # type: ignore # This should never be None but it's volatile user = data.get('bot') if user is not None: self.bot = User(state=self._state, data=user) diff --git a/discord/asset.py b/discord/asset.py index 018f2dcbd..99fd57fe0 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -29,6 +29,7 @@ import os from typing import Any, Literal, Optional, TYPE_CHECKING, Tuple, Union from .errors import DiscordException from . import utils +from .file import File import yarl @@ -92,7 +93,7 @@ class AssetMixin: Parameters ---------- fp: Union[:class:`io.BufferedIOBase`, :class:`os.PathLike`] - The file-like object to save this attachment to or the filename + The file-like object to save this asset to or the filename to use. If a filename is passed then a file is created with that filename and used instead. seek_begin: :class:`bool` @@ -124,6 +125,43 @@ class AssetMixin: with open(fp, 'wb') as f: return f.write(data) + async def to_file(self, *, spoiler: bool = False) -> File: + """|coro| + + Converts the asset into a :class:`File` suitable for sending via + :meth:`abc.Messageable.send`. + + .. versionadded:: 2.0 + + Parameters + ----------- + spoiler: :class:`bool` + Whether the file is a spoiler. + + Raises + ------ + DiscordException + The asset does not have an associated state. + TypeError + The asset is a sticker with lottie type. + HTTPException + Downloading the asset failed. + Forbidden + You do not have permissions to access this asset. + NotFound + The asset was deleted. + + Returns + ------- + :class:`File` + The asset as a file suitable for sending. + """ + + data = await self.read() + url = yarl.URL(self.url) + _, _, filename = url.path.rpartition('/') + return File(io.BytesIO(data), filename=filename, spoiler=spoiler) + class Asset(AssetMixin): """Represents a CDN asset on Discord. diff --git a/discord/audit_logs.py b/discord/audit_logs.py index dbf7b5396..b7a6af46b 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -132,13 +132,19 @@ def _transform_icon(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset if entry.action is enums.AuditLogAction.guild_update: return Asset._from_guild_icon(entry._state, entry.guild.id, data) else: - return Asset._from_icon(entry._state, entry._target_id, data, path='role') # type: ignore - target_id won't be None in this case + return Asset._from_icon(entry._state, entry._target_id, data, path='role') # type: ignore # target_id won't be None in this case def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: if data is None: return None - return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore - target_id won't be None in this case + return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore # target_id won't be None in this case + + +def _transform_cover_image(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: + if data is None: + return None + return Asset._from_scheduled_event_cover_image(entry._state, entry._target_id, data) # type: ignore # target_id won't be None in this case def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]: @@ -238,6 +244,8 @@ class AuditLogChanges: 'mfa_level': (None, _enum_transformer(enums.MFALevel)), 'status': (None, _enum_transformer(enums.EventStatus)), 'entity_type': (None, _enum_transformer(enums.EntityType)), + 'preferred_locale': (None, _enum_transformer(enums.Locale)), + 'image_hash': ('cover_image', _transform_cover_image), } # fmt: on @@ -250,10 +258,10 @@ class AuditLogChanges: # Special cases for role add/remove if attr == '$add': - self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore - new_value is a list of roles in this case + self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore # new_value is a list of roles in this case continue elif attr == '$remove': - self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore - new_value is a list of roles in this case + self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore # new_value is a list of roles in this case continue try: @@ -310,7 +318,7 @@ class AuditLogChanges: if role is None: role = Object(id=role_id) - role.name = e['name'] # type: ignore - Object doesn't usually have name + role.name = e['name'] # type: ignore # Object doesn't usually have name data.append(role) @@ -448,7 +456,7 @@ class AuditLogEntry(Hashable): role = self.guild.get_role(instance_id) if role is None: role = Object(id=instance_id) - role.name = self.extra.get('role_name') # type: ignore - Object doesn't usually have name + role.name = self.extra.get('role_name') # type: ignore # Object doesn't usually have name self.extra = role elif self.action.name.startswith('stage_instance'): channel_id = int(extra['channel_id']) @@ -540,7 +548,7 @@ class AuditLogEntry(Hashable): 'code': changeset.code, 'temporary': changeset.temporary, 'uses': changeset.uses, - 'channel': None, # type: ignore - the channel is passed to the Invite constructor directly + 'channel': None, # type: ignore # the channel is passed to the Invite constructor directly } obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) diff --git a/discord/calls.py b/discord/calls.py index ef6438229..696bb3f4b 100644 --- a/discord/calls.py +++ b/discord/calls.py @@ -89,12 +89,12 @@ class CallMessage: @property def initiator(self) -> User: """:class:`User`: Returns the user that started the call.""" - return self.message.author # type: ignore - Cannot be a Member in private messages + return self.message.author # type: ignore # Cannot be a Member in private messages @property def channel(self) -> _PrivateChannel: r""":class:`PrivateChannel`\: The private channel associated with this message.""" - return self.message.channel # type: ignore - Can only be a private channel here + return self.message.channel # type: ignore # Can only be a private channel here @property def duration(self) -> datetime.timedelta: @@ -186,7 +186,7 @@ class PrivateCall: def initiator(self) -> Optional[User]: """Optional[:class:`User`]: Returns the user that started the call. The call message must be available to obtain this information.""" if self.message: - return self.message.author # type: ignore - Cannot be a Member in private messages + return self.message.author # type: ignore # Cannot be a Member in private messages @property def connected(self) -> bool: diff --git a/discord/channel.py b/discord/channel.py index 0189b5cbf..6c866e551 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -60,7 +60,6 @@ __all__ = ( 'StageChannel', 'DMChannel', 'CategoryChannel', - 'StoreChannel', 'GroupChannel', 'PartialMessageable', ) @@ -84,7 +83,6 @@ if TYPE_CHECKING: StageChannel as StageChannelPayload, DMChannel as DMChannelPayload, CategoryChannel as CategoryChannelPayload, - StoreChannel as StoreChannelPayload, GroupDMChannel as GroupChannelPayload, ) @@ -468,32 +466,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): List[:class:`.Message`] The list of messages that were deleted. """ - if check is MISSING: - check = lambda m: True - - state = self._state - channel_id = self.id - iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) - ret: List[Message] = [] - count = 0 - - async for message in iterator: - if count == 50: - to_delete = ret[-50:] - await state._delete_messages(channel_id, to_delete) - count = 0 - - if not check(message): - continue - - count += 1 - ret.append(message) - - # Some messages remaining to poll - to_delete = ret[-count:] - await state._delete_messages(channel_id, to_delete, reason=reason) - - return ret + return await discord.abc._purge_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + reason=reason, + ) async def webhooks(self) -> List[Webhook]: """|coro| @@ -554,7 +536,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): from .webhook import Webhook if avatar is not None: - avatar = utils._bytes_to_base64_data(avatar) # type: ignore - Silence reassignment error + avatar = utils._bytes_to_base64_data(avatar) # type: ignore # Silence reassignment error data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) @@ -748,7 +730,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): limit: Optional[int] = 100, before: Optional[Union[Snowflake, datetime.datetime]] = None, ) -> AsyncIterator[Thread]: - """Returns an :term:`asynchronous iterator` that iterates over all archived threads in the guild, + """Returns an :term:`asynchronous iterator` that iterates over all archived threads in this text channel, in order of decreasing ID for joined threads, and decreasing :attr:`Thread.archive_timestamp` otherwise. You must have :attr:`~Permissions.read_message_history` to use this. If iterating over private threads @@ -777,7 +759,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): HTTPException The request to get the archived threads failed. ValueError - `joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived + ``joined`` was set to ``True`` and ``private`` was set to ``False``. You cannot retrieve public archived threads that you have joined. Yields @@ -848,6 +830,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha 'category_id', 'rtc_region', 'video_quality_mode', + 'last_message_id', ) def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]): @@ -867,6 +850,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha self.rtc_region: Optional[str] = data.get('rtc_region') self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self.position: int = data['position'] self.bitrate: int = data['bitrate'] self.user_limit: int = data['user_limit'] @@ -932,7 +916,7 @@ class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hasha return base -class VoiceChannel(VocalGuildChannel): +class VoiceChannel(discord.abc.Messageable, VocalGuildChannel): """Represents a Discord guild voice channel. .. container:: operations @@ -981,6 +965,11 @@ class VoiceChannel(VocalGuildChannel): video_quality_mode: :class:`VideoQualityMode` The camera video quality for the voice channel's participants. + .. versionadded:: 2.0 + last_message_id: Optional[:class:`int`] + The last message ID of the message sent to this channel. It may + *not* point to an existing or valid message. + .. versionadded:: 2.0 """ @@ -1000,11 +989,234 @@ class VoiceChannel(VocalGuildChannel): joined = ' '.join('%s=%r' % t for t in attrs) return f'<{self.__class__.__name__} {joined}>' + async def _get_channel(self) -> Self: + return self + @property def type(self) -> ChannelType: """:class:`ChannelType`: The channel's Discord type.""" return ChannelType.voice + @property + def last_message(self) -> Optional[Message]: + """Fetches the last message from this channel in cache. + + The message might not be valid or point to an existing message. + + .. versionadded:: 2.0 + + .. admonition:: Reliable Fetching + :class: helpful + + For a slightly more reliable method of fetching the + last message, consider using either :meth:`history` + or :meth:`fetch_message` with the :attr:`last_message_id` + attribute. + + Returns + --------- + Optional[:class:`Message`] + The last message in this channel or ``None`` if not found. + """ + return self._state._get_message(self.last_message_id) if self.last_message_id else None + + def get_partial_message(self, message_id: int, /) -> PartialMessage: + """Creates a :class:`PartialMessage` from the message ID. + + This is useful if you want to work with a message and only have its ID without + doing an unnecessary API call. + + .. versionadded:: 2.0 + + Parameters + ------------ + message_id: :class:`int` + The message ID to create a partial message for. + + Returns + --------- + :class:`PartialMessage` + The partial message. + """ + + from .message import PartialMessage + + return PartialMessage(channel=self, id=message_id) + + async def delete_messages(self, messages: Iterable[Snowflake], /, *, reason: Optional[str] = None) -> None: + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + You must have the :attr:`~Permissions.manage_messages` permission to + use this (unless they're your own). + + .. note:: + Users do not have access to the message bulk-delete endpoint. + Since messages are just iterated over and deleted one-by-one, + it's easy to get ratelimited using this method. + + Parameters + ----------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. + + Raises + ------ + Forbidden + You do not have proper permissions to delete the messages. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # Do nothing + + await self._state._delete_messages(self.id, messages, reason=reason) + + async def purge( + self, + *, + limit: Optional[int] = 100, + check: Callable[[Message], bool] = MISSING, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = False, + reason: Optional[str] = None, + ) -> List[Message]: + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + The :attr:`~Permissions.read_message_history` permission is needed to + retrieve message history. + + Examples + --------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send(f'Deleted {len(deleted)} message(s)') + + Parameters + ----------- + limit: Optional[:class:`int`] + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: Callable[[:class:`Message`], :class:`bool`] + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``before`` in :meth:`history`. + after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``after`` in :meth:`history`. + around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] + Same as ``around`` in :meth:`history`. + oldest_first: Optional[:class:`bool`] + Same as ``oldest_first`` in :meth:`history`. + reason: Optional[:class:`str`] + The reason for purging the messages. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Returns + -------- + List[:class:`.Message`] + The list of messages that were deleted. + """ + return await discord.abc._purge_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + reason=reason, + ) + + async def webhooks(self) -> List[Webhook]: + """|coro| + + Gets the list of webhooks from this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + .. versionadded:: 2.0 + + Raises + ------- + Forbidden + You don't have permissions to get the webhooks. + + Returns + -------- + List[:class:`Webhook`] + The webhooks for this channel. + """ + + from .webhook import Webhook + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: + """|coro| + + Creates a webhook for this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + .. versionadded:: 2.0 + + Parameters + ------------- + name: :class:`str` + The webhook's name. + avatar: Optional[:class:`bytes`] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + reason: Optional[:class:`str`] + The reason for creating this webhook. Shows up in the audit logs. + + Raises + ------- + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + + Returns + -------- + :class:`Webhook` + The created webhook. + """ + + from .webhook import Webhook + + if avatar is not None: + avatar = utils._bytes_to_base64_data(avatar) # type: ignore # Silence reassignment error + + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) + return Webhook.from_state(data, state=self._state) + @utils.copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason) @@ -1613,180 +1825,6 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return await self.guild.create_stage_channel(name, category=self, **options) -class StoreChannel(discord.abc.GuildChannel, Hashable): - """Represents a Discord guild store channel. - - .. container:: operations - - .. describe:: x == y - - Checks if two channels are equal. - - .. describe:: x != y - - Checks if two channels are not equal. - - .. describe:: hash(x) - - Returns the channel's hash. - - .. describe:: str(x) - - Returns the channel's name. - - Attributes - ----------- - name: :class:`str` - The channel name. - guild: :class:`Guild` - The guild the channel belongs to. - id: :class:`int` - The channel ID. - category_id: :class:`int` - The category channel ID this channel belongs to. - position: :class:`int` - The position in the channel list. This is a number that starts at 0. e.g. the - top channel is position 0. - nsfw: :class:`bool` - If the channel is marked as "not safe for work". - - .. note:: - - To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. - """ - - __slots__ = ( - 'name', - 'id', - 'guild', - '_state', - 'nsfw', - 'category_id', - 'position', - '_overwrites', - ) - - def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload): - self._state: ConnectionState = state - self.id: int = int(data['id']) - self._update(guild, data) - - def __repr__(self) -> str: - return f'' - - def _update(self, guild: Guild, data: StoreChannelPayload) -> None: - self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.position: int = data['position'] - self.nsfw: bool = data.get('nsfw', False) - self._fill_overwrites(data) - - @property - def _sorting_bucket(self) -> int: - return ChannelType.text.value - - @property - def type(self) -> ChannelType: - """:class:`ChannelType`: The channel's Discord type.""" - return ChannelType.store - - @utils.copy_doc(discord.abc.GuildChannel.permissions_for) - def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: - base = super().permissions_for(obj) - - # store channels do not have voice related permissions - denied = Permissions.voice() - base.value &= ~denied.value - return base - - def is_nsfw(self) -> bool: - """:class:`bool`: Checks if the channel is NSFW.""" - return self.nsfw - - @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel: - return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) - - @overload - async def edit( - self, - *, - name: str = ..., - position: int = ..., - nsfw: bool = ..., - sync_permissions: bool = ..., - category: Optional[CategoryChannel] = ..., - reason: Optional[str] = ..., - overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ..., - ) -> Optional[StoreChannel]: - ... - - @overload - async def edit(self) -> Optional[StoreChannel]: - ... - - async def edit(self, *, reason: Optional[str] = None, **options: Any) -> Optional[StoreChannel]: - """|coro| - - Edits the channel. - - You must have the :attr:`~Permissions.manage_channels` permission to - use this. - - .. versionchanged:: 2.0 - Edits are no longer in-place, the newly edited channel is returned instead. - - .. versionchanged:: 2.0 - This function will now raise :exc:`TypeError` or - :exc:`ValueError` instead of ``InvalidArgument``. - - Parameters - ---------- - name: :class:`str` - The new channel name. - position: :class:`int` - The new channel's position. - nsfw: :class:`bool` - To mark the channel as NSFW or not. - sync_permissions: :class:`bool` - Whether to sync permissions with the channel's new or pre-existing - category. Defaults to ``False``. - category: Optional[:class:`CategoryChannel`] - The new category for this channel. Can be ``None`` to remove the - category. - reason: Optional[:class:`str`] - The reason for editing this channel. Shows up on the audit log. - overwrites: :class:`Mapping` - A :class:`Mapping` of target (either a role or a member) to - :class:`PermissionOverwrite` to apply to the channel. - - .. versionadded:: 1.3 - - Raises - ------ - ValueError - The new ``position`` is less than 0 or greater than the number of channels. - TypeError - The permission overwrite information is not in proper form. - Forbidden - You do not have permissions to edit the channel. - HTTPException - Editing the channel failed. - - Returns - -------- - Optional[:class:`.StoreChannel`] - The newly edited store channel. If the edit was only positional - then ``None`` is returned instead. - """ - - payload = await self._edit(options, reason=reason) - if payload is not None: - # the payload will always be the proper channel payload - return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - - class DMChannel(discord.abc.Messageable, discord.abc.Connectable, Hashable): """Represents a Discord direct message channel. @@ -2404,8 +2442,6 @@ def _guild_channel_factory(channel_type: int): return CategoryChannel, value elif value is ChannelType.news: return TextChannel, value - elif value is ChannelType.store: - return StoreChannel, value elif value is ChannelType.stage_voice: return StageChannel, value else: diff --git a/discord/client.py b/discord/client.py index 78a55b94e..e2174e069 100644 --- a/discord/client.py +++ b/discord/client.py @@ -594,7 +594,7 @@ class Client: except ReconnectWebSocket as e: _log.info('Got a request to %s the websocket.', e.op) self.dispatch('disconnect') - ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) # type: ignore - These are always present at this point + ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) # type: ignore # These are always present at this point continue except ( OSError, @@ -618,7 +618,7 @@ class Client: # If we get connection reset by peer then try to RESUME if isinstance(exc, OSError) and exc.errno in (54, 10054): - ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id) # type: ignore - These are always present at this point + ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id) # type: ignore # These are always present at this point continue # We should only get this when an unhandled close code happens, @@ -636,7 +636,7 @@ class Client: # Always try to RESUME the connection # If the connection is not RESUME-able then the gateway will invalidate the session # This is apparently what the official Discord client does - ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) # type: ignore - These are always present at this point + ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) # type: ignore # These are always present at this point async def close(self) -> None: """|coro| @@ -973,7 +973,7 @@ class Client: Optional[Union[:class:`.abc.GuildChannel`, :class:`.Thread`, :class:`.abc.PrivateChannel`]] The returned channel or ``None`` if not found. """ - return self._connection.get_channel(id) # type: ignore - The cache contains all channel types + return self._connection.get_channel(id) # type: ignore # The cache contains all channel types def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable: """Returns a partial messageable with the given channel ID. @@ -1372,11 +1372,11 @@ class Client: custom_activity = activity payload: Dict[str, Any] = {} - if status != getattr(self.user.settings, 'status', None): # type: ignore - user is always present when logged in + if status != getattr(self.user.settings, 'status', None): # type: ignore # user is always present when logged in payload['status'] = status - if custom_activity != getattr(self.user.settings, 'custom_activity', None): # type: ignore - user is always present when logged in + if custom_activity != getattr(self.user.settings, 'custom_activity', None): # type: ignore # user is always present when logged in payload['custom_activity'] = custom_activity - await self.user.edit_settings(**payload) # type: ignore - user is always present when logged in + await self.user.edit_settings(**payload) # type: ignore # user is always present when logged in status_str = str(status) activities_tuple = tuple(a.to_dict() for a in activities) @@ -1574,7 +1574,7 @@ class Client: Creates a :class:`.Guild`. .. versionchanged:: 2.0 - ``name`` and ``icon`` parameters are now keyword-only. The `region`` parameter has been removed. + ``name`` and ``icon`` parameters are now keyword-only. The ``region`` parameter has been removed. .. versionchanged:: 2.0 This function will now raise :exc:`ValueError` instead of @@ -2224,7 +2224,7 @@ class Client: """ state = self._connection channels = await state.http.get_private_channels() - return [_private_channel_factory(data['type'])[0](me=self.user, data=data, state=state) for data in channels] # type: ignore - user is always present when logged in + return [_private_channel_factory(data['type'])[0](me=self.user, data=data, state=state) for data in channels] # type: ignore # user is always present when logged in async def create_dm(self, user: Snowflake, /) -> DMChannel: """|coro| @@ -2282,7 +2282,7 @@ class Client: users: List[_Snowflake] = [u.id for u in recipients] state = self._connection data = await state.http.start_group(users) - return GroupChannel(me=self.user, data=data, state=state) # type: ignore - user is always present when logged in + return GroupChannel(me=self.user, data=data, state=state) # type: ignore # user is always present when logged in @overload async def send_friend_request(self, user: BaseUser, /) -> Relationship: diff --git a/discord/colour.py b/discord/colour.py index 5308cb74f..761405702 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -26,14 +26,7 @@ from __future__ import annotations import colorsys import random -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Optional, Tuple, Union if TYPE_CHECKING: from typing_extensions import Self diff --git a/discord/commands.py b/discord/commands.py index fbb615c3c..c2fb625ac 100644 --- a/discord/commands.py +++ b/discord/commands.py @@ -29,7 +29,8 @@ from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, runtime_che from .enums import AppCommandOptionType, AppCommandType, ChannelType, InteractionType, try_enum from .errors import InvalidData -from .utils import _generate_session_id, time_snowflake +from .mixins import Hashable +from .utils import time_snowflake if TYPE_CHECKING: from .abc import Messageable, Snowflake @@ -113,7 +114,7 @@ class ApplicationCommand(Protocol): return i -class BaseCommand(ApplicationCommand): +class BaseCommand(ApplicationCommand, Hashable): """Represents a base command. Attributes diff --git a/discord/components.py b/discord/components.py index d439025db..054b8c44f 100644 --- a/discord/components.py +++ b/discord/components.py @@ -135,7 +135,7 @@ class ActionRow(Component): return { 'type': int(self.type), 'components': [child.to_dict() for child in self.children], - } # type: ignore - Type checker does not understand these are the same + } # type: ignore # Type checker does not understand these are the same class Button(Component): diff --git a/discord/embeds.py b/discord/embeds.py index 8ea4086a8..e28b0c59e 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -709,4 +709,4 @@ class Embed: if self.title: result['title'] = self.title - return result # type: ignore - This payload is equivalent to the EmbedData type + return result # type: ignore # This payload is equivalent to the EmbedData type diff --git a/discord/emoji.py b/discord/emoji.py index d6db70657..89b2dfe98 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -116,8 +116,8 @@ class Emoji(_EmojiTag, AssetMixin): def _from_data(self, emoji: EmojiPayload): self.require_colons: bool = emoji.get('require_colons', False) self.managed: bool = emoji.get('managed', False) - self.id: int = int(emoji['id']) # type: ignore - This won't be None for full emoji objects. - self.name: str = emoji['name'] # type: ignore - This won't be None for full emoji objects. + self.id: int = int(emoji['id']) # type: ignore # This won't be None for full emoji objects. + self.name: str = emoji['name'] # type: ignore # This won't be None for full emoji objects. self.animated: bool = emoji.get('animated', False) self.available: bool = emoji.get('available', True) self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', []))) @@ -256,7 +256,7 @@ class Emoji(_EmojiTag, AssetMixin): payload['roles'] = [role.id for role in roles] data = await self._state.http.edit_custom_emoji(self.guild_id, self.id, payload=payload, reason=reason) - return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore - if guild is None, the http request would have failed + return Emoji(guild=self.guild, data=data, state=self._state) # type: ignore # If guild is None, the http request would have failed async def fetch_guild(self): """|coro| diff --git a/discord/enums.py b/discord/enums.py index 8ae7a3433..195252306 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -145,7 +145,7 @@ class EnumMeta(type): attrs['_enum_member_names_'] = member_names attrs['_enum_value_cls_'] = value_cls actual_cls = super().__new__(cls, name, bases, attrs) - value_cls._actual_enum_cls_ = actual_cls # type: ignore - Runtime attribute isn't understood + value_cls._actual_enum_cls_ = actual_cls # type: ignore # Runtime attribute isn't understood return actual_cls def __iter__(cls) -> Iterator[Any]: @@ -873,7 +873,7 @@ class AppCommandType(Enum): def create_unknown_value(cls: Type[E], val: Any) -> E: - value_cls = cls._enum_value_cls_ # type: ignore - This is narrowed below + value_cls = cls._enum_value_cls_ # type: ignore # This is narrowed below name = f'unknown_{val}' return value_cls(name=name, value=val) @@ -885,6 +885,6 @@ def try_enum(cls: Type[E], val: Any) -> E: """ try: - return cls._enum_value_map_[val] # type: ignore - All errors are caught below + return cls._enum_value_map_[val] # type: ignore # All errors are caught below except (KeyError, TypeError, AttributeError): return create_unknown_value(cls, val) diff --git a/discord/errors.py b/discord/errors.py index b89a3a4c2..17a719760 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -115,7 +115,7 @@ class HTTPException(DiscordException): def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]): self.response: _ResponseType = response - self.status: int = response.status # type: ignore - This attribute is filled by the library even if using requests + self.status: int = response.status # type: ignore # This attribute is filled by the library even if using requests self.code: int self.text: str if isinstance(message, dict): diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 25926f405..c0c782fee 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE. """ -from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple +from typing import Any, Awaitable, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union, Tuple T = TypeVar('T') @@ -37,18 +37,16 @@ if TYPE_CHECKING: from .errors import CommandError P = ParamSpec('P') - MaybeCoroFunc = Union[ - Callable[P, 'Coro[T]'], - Callable[P, T], - ] + MaybeAwaitableFunc = Callable[P, 'MaybeAwaitable[T]'] else: P = TypeVar('P') - MaybeCoroFunc = Tuple[P, T] + MaybeAwaitableFunc = Tuple[P, T] _Bot = Bot Coro = Coroutine[Any, Any, T] -MaybeCoro = Union[T, Coro[T]] CoroFunc = Callable[..., Coro[Any]] +MaybeCoro = Union[T, Coro[T]] +MaybeAwaitable = Union[T, Awaitable[T]] Check = Union[Callable[["Cog", "ContextT"], MaybeCoro[bool]], Callable[["ContextT"], MaybeCoro[bool]]] Hook = Union[Callable[["Cog", "ContextT"], Coro[Any]], Callable[["ContextT"], Coro[Any]]] diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index a910a0722..3a83006b3 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -65,18 +65,18 @@ if TYPE_CHECKING: import importlib.machinery from discord.message import Message - from discord.abc import User, Snowflake + from discord.abc import User from ._types import ( _Bot, BotT, Check, CoroFunc, ContextT, - MaybeCoroFunc, + MaybeAwaitableFunc, ) _Prefix = Union[Iterable[str], str] - _PrefixCallable = MaybeCoroFunc[[BotT, Message], _Prefix] + _PrefixCallable = MaybeAwaitableFunc[[BotT, Message], _Prefix] PrefixType = Union[_Prefix, _PrefixCallable[BotT]] __all__ = ( @@ -152,22 +152,20 @@ class BotBase(GroupMixin[None]): def __init__( self, command_prefix: PrefixType[BotT], - help_command: Optional[HelpCommand[Any]] = _default, + help_command: Optional[HelpCommand] = _default, description: Optional[str] = None, **options: Any, ) -> None: super().__init__(**options) self.command_prefix: PrefixType[BotT] = command_prefix self.extra_events: Dict[str, List[CoroFunc]] = {} - # Self doesn't have the ClientT bound, but since this is a mixin it technically does - self.__tree: app_commands.CommandTree[Self] = app_commands.CommandTree(self) # type: ignore self.__cogs: Dict[str, Cog] = {} self.__extensions: Dict[str, types.ModuleType] = {} self._checks: List[Check] = [] self._check_once: List[Check] = [] self._before_invoke: Optional[CoroFunc] = None self._after_invoke: Optional[CoroFunc] = None - self._help_command: Optional[HelpCommand[Any]] = None + self._help_command: Optional[HelpCommand] = None self.description: str = inspect.cleandoc(description) if description else '' self.owner_id: Optional[int] = options.get('owner_id') self.owner_ids: Optional[Collection[int]] = options.get('owner_ids', set()) @@ -594,8 +592,6 @@ class BotBase(GroupMixin[None]): /, *, override: bool = False, - guild: Optional[Snowflake] = MISSING, - guilds: List[Snowflake] = MISSING, ) -> None: """|coro| @@ -603,9 +599,6 @@ class BotBase(GroupMixin[None]): A cog is a class that has its own event listeners and commands. - If the cog is a :class:`.app_commands.Group` then it is added to - the bot's :class:`~discord.app_commands.CommandTree` as well. - .. note:: Exceptions raised inside a :class:`.Cog`'s :meth:`~.Cog.cog_load` method will be @@ -632,19 +625,6 @@ class BotBase(GroupMixin[None]): If a previously loaded cog with the same name should be ejected instead of raising an error. - .. versionadded:: 2.0 - guild: Optional[:class:`~discord.abc.Snowflake`] - If the cog is an application command group, then this would be the - guild where the cog group would be added to. If not given then - it becomes a global command instead. - - .. versionadded:: 2.0 - guilds: List[:class:`~discord.abc.Snowflake`] - If the cog is an application command group, then this would be the - guilds where the cog group would be added to. If not given then - it becomes a global command instead. Cannot be mixed with - ``guild``. - .. versionadded:: 2.0 Raises @@ -666,12 +646,9 @@ class BotBase(GroupMixin[None]): if existing is not None: if not override: raise discord.ClientException(f'Cog named {cog_name!r} already loaded') - await self.remove_cog(cog_name, guild=guild, guilds=guilds) - - if isinstance(cog, app_commands.Group): - self.__tree.add_command(cog, override=override, guild=guild, guilds=guilds) + await self.remove_cog(cog_name) - cog = await cog._inject(self, override=override, guild=guild, guilds=guilds) + cog = await cog._inject(self, override=override) self.__cogs[cog_name] = cog def get_cog(self, name: str, /) -> Optional[Cog]: @@ -701,9 +678,6 @@ class BotBase(GroupMixin[None]): self, name: str, /, - *, - guild: Optional[Snowflake] = MISSING, - guilds: List[Snowflake] = MISSING, ) -> Optional[Cog]: """|coro| @@ -726,19 +700,6 @@ class BotBase(GroupMixin[None]): ----------- name: :class:`str` The name of the cog to remove. - guild: Optional[:class:`~discord.abc.Snowflake`] - If the cog is an application command group, then this would be the - guild where the cog group would be removed from. If not given then - a global command is removed instead instead. - - .. versionadded:: 2.0 - guilds: List[:class:`~discord.abc.Snowflake`] - If the cog is an application command group, then this would be the - guilds where the cog group would be removed from. If not given then - a global command is removed instead instead. Cannot be mixed with - ``guild``. - - .. versionadded:: 2.0 Returns ------- @@ -754,15 +715,7 @@ class BotBase(GroupMixin[None]): if help_command and help_command.cog is cog: help_command.cog = None - guild_ids = _retrieve_guild_ids(cog, guild, guilds) - if isinstance(cog, app_commands.Group): - if guild_ids is None: - self.__tree.remove_command(name) - else: - for guild_id in guild_ids: - self.__tree.remove_command(name, guild=discord.Object(guild_id)) - - await cog._eject(self, guild_ids=guild_ids) + await cog._eject(self) return cog @@ -797,9 +750,6 @@ class BotBase(GroupMixin[None]): for index in reversed(remove): del event_list[index] - # remove all relevant application commands from the tree - self.__tree._remove_with_module(name) - async def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: func = getattr(lib, 'teardown') @@ -1023,11 +973,11 @@ class BotBase(GroupMixin[None]): # help command stuff @property - def help_command(self) -> Optional[HelpCommand[Any]]: + def help_command(self) -> Optional[HelpCommand]: return self._help_command @help_command.setter - def help_command(self, value: Optional[HelpCommand[Any]]) -> None: + def help_command(self, value: Optional[HelpCommand]) -> None: if value is not None: if not isinstance(value, HelpCommand): raise TypeError('help_command must be a subclass of HelpCommand') @@ -1041,20 +991,6 @@ class BotBase(GroupMixin[None]): else: self._help_command = None - # application command interop - - # As mentioned above, this is a mixin so the Self type hint fails here. - # However, since the only classes that can use this are subclasses of Client - # anyway, then this is sound. - @property - def tree(self) -> app_commands.CommandTree[Self]: # type: ignore - """:class:`~discord.app_commands.CommandTree`: The command tree responsible for handling the application commands - in this bot. - - .. versionadded:: 2.0 - """ - return self.__tree - # command processing async def get_prefix(self, message: Message, /) -> Union[List[str], str]: @@ -1079,6 +1015,7 @@ class BotBase(GroupMixin[None]): listening for. """ prefix = ret = self.command_prefix + if callable(prefix): # self will be a Bot or AutoShardedBot ret = await discord.utils.maybe_coroutine(prefix, self, message) # type: ignore @@ -1097,9 +1034,6 @@ class BotBase(GroupMixin[None]): f"returning either of these, not {ret.__class__.__name__}" ) - if not ret: - raise ValueError("Iterable command_prefix must contain at least one prefix") - return ret @overload @@ -1306,8 +1240,7 @@ class Bot(BotBase, discord.Client): The command prefix could also be an iterable of strings indicating that multiple checks for the prefix should be used and the first one to match will be the invocation prefix. You can get this prefix via - :attr:`.Context.prefix`. To avoid confusion empty iterables are not - allowed. + :attr:`.Context.prefix`. .. note:: diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 1cd4f8895..b69a47d25 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -25,16 +25,14 @@ from __future__ import annotations import inspect import discord -from discord import app_commands from discord.utils import maybe_coroutine -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar from ._types import _BaseCommand, BotT if TYPE_CHECKING: from typing_extensions import Self - from discord.abc import Snowflake from .bot import BotBase from .context import Context @@ -113,34 +111,23 @@ class CogMeta(type): __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[Command[Any, ..., Any]] - __cog_is_app_commands_group__: bool - __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Any, ..., Any]]] __cog_listeners__: List[Tuple[str, str]] def __new__(cls, *args: Any, **kwargs: Any) -> Self: name, bases, attrs = args attrs['__cog_name__'] = kwargs.get('name', name) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) - is_parent = any(issubclass(base, app_commands.Group) for base in bases) - attrs['__cog_is_app_commands_group__'] = is_parent description = kwargs.get('description', None) if description is None: description = inspect.cleandoc(attrs.get('__doc__', '')) attrs['__cog_description__'] = description - if is_parent: - attrs['__discord_app_commands_skip_init_binding__'] = True - # This is hacky, but it signals the Group not to process this info. - # It's overridden later. - attrs['__discord_app_commands_group_children__'] = True - else: - # Remove the extraneous keyword arguments we're using - kwargs.pop('name', None) - kwargs.pop('description', None) + # Remove the extraneous keyword arguments we're using + kwargs.pop('name', None) + kwargs.pop('description', None) commands = {} - cog_app_commands = {} listeners = {} no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' @@ -161,8 +148,6 @@ class CogMeta(type): if elem.startswith(('cog_', 'bot_')): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value - elif isinstance(value, (app_commands.Group, app_commands.Command)) and value.parent is None: - cog_app_commands[elem] = value elif inspect.iscoroutinefunction(value): try: getattr(value, '__cog_listener__') @@ -174,13 +159,6 @@ class CogMeta(type): listeners[elem] = value new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ - new_cls.__cog_app_commands__ = list(cog_app_commands.values()) - - if is_parent: - # Prefill the app commands for the Group as well.. - # The type checker doesn't like runtime attribute modification and this one's - # optional so it can't be cheesed. - new_cls.__discord_app_commands_group_children__ = new_cls.__cog_app_commands__ # type: ignore listeners_as_list = [] for listener in listeners.values(): @@ -219,7 +197,6 @@ class Cog(metaclass=CogMeta): __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[Command[Self, ..., Any]] - __cog_app_commands__: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] __cog_listeners__: List[Tuple[str, str]] def __new__(cls, *args: Any, **kwargs: Any) -> Self: @@ -247,27 +224,6 @@ class Cog(metaclass=CogMeta): parent.remove_command(command.name) # type: ignore parent.add_command(command) # type: ignore - # Register the application commands - children: List[Union[app_commands.Group, app_commands.Command[Self, ..., Any]]] = [] - for command in cls.__cog_app_commands__: - if cls.__cog_is_app_commands_group__: - # Type checker doesn't understand this type of narrowing. - # Not even with TypeGuard somehow. - command.parent = self # type: ignore - - copy = command._copy_with_binding(self) - - children.append(copy) - if command._attr: - setattr(self, command._attr, copy) - - self.__cog_app_commands__ = children - if cls.__cog_is_app_commands_group__: - # Dynamic attribute setting - self.__discord_app_commands_group_children__ = children # type: ignore - # Enforce this to work even if someone forgets __init__ - self.module = cls.__module__ # type: ignore - return self def get_commands(self) -> List[Command[Self, ..., Any]]: @@ -485,7 +441,7 @@ class Cog(metaclass=CogMeta): """ pass - async def _inject(self, bot: BotBase, override: bool, guild: Optional[Snowflake], guilds: List[Snowflake]) -> Self: + async def _inject(self, bot: BotBase, override: bool) -> Self: cls = self.__class__ # we'll call this first so that errors can propagate without @@ -523,15 +479,9 @@ class Cog(metaclass=CogMeta): for name, method_name in self.__cog_listeners__: bot.add_listener(getattr(self, method_name), name) - # Only do this if these are "top level" commands - if not cls.__cog_is_app_commands_group__: - for command in self.__cog_app_commands__: - # This is already atomic - bot.tree.add_command(command, override=override, guild=guild, guilds=guilds) - return self - async def _eject(self, bot: BotBase, guild_ids: Optional[Iterable[int]]) -> None: + async def _eject(self, bot: BotBase) -> None: cls = self.__class__ try: @@ -539,15 +489,6 @@ class Cog(metaclass=CogMeta): if command.parent is None: bot.remove_command(command.name) - if not cls.__cog_is_app_commands_group__: - for command in self.__cog_app_commands__: - guild_ids = guild_ids or command._guild_ids - if guild_ids is None: - bot.tree.remove_command(command.name) - else: - for guild_id in guild_ids: - bot.tree.remove_command(command.name, guild=discord.Object(id=guild_id)) - for name, method_name in self.__cog_listeners__: bot.remove_listener(getattr(self, method_name), name) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 2192b70dd..55108d9d7 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -45,7 +45,6 @@ if TYPE_CHECKING: from discord.user import ClientUser, User from discord.voice_client import VoiceProtocol - from .bot import Bot from .cog import Cog from .core import Command from .view import StringView @@ -95,6 +94,11 @@ class Context(discord.abc.Messageable, Generic[BotT]): The parameter that is currently being inspected and converted. This is only of use for within converters. + .. versionadded:: 2.0 + current_argument: Optional[:class:`str`] + The argument string of the :attr:`current_parameter` that is currently being converted. + This is only of use for within converters. + .. versionadded:: 2.0 prefix: Optional[:class:`str`] The prefix that was used to invoke the command. @@ -140,6 +144,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): subcommand_passed: Optional[str] = None, command_failed: bool = False, current_parameter: Optional[inspect.Parameter] = None, + current_argument: Optional[str] = None, ): self.message: Message = message self.bot: BotT = bot @@ -154,6 +159,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): self.subcommand_passed: Optional[str] = subcommand_passed self.command_failed: bool = command_failed self.current_parameter: Optional[inspect.Parameter] = current_parameter + self.current_argument: Optional[str] = current_argument self._state: ConnectionState = self.message._state async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: @@ -353,7 +359,6 @@ class Context(discord.abc.Messageable, Generic[BotT]): """ from .core import Group, Command, wrap_callback from .errors import CommandError - from .help import _context bot = self.bot cmd = bot.help_command @@ -361,7 +366,8 @@ class Context(discord.abc.Messageable, Generic[BotT]): if cmd is None: return None - _context.set(self) + cmd = cmd.copy() + cmd.context = self if len(args) == 0: await cmd.prepare_help_command(self, None) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index ddd688b44..28de883a9 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -74,7 +74,6 @@ __all__ = ( 'PartialEmojiConverter', 'CategoryChannelConverter', 'IDConverter', - 'StoreChannelConverter', 'ThreadConverter', 'GuildChannelConverter', 'GuildStickerConverter', @@ -375,7 +374,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) channel = self._resolve_channel(ctx, guild_id, channel_id) if not channel or not isinstance(channel, discord.abc.Messageable): - raise ChannelNotFound(channel_id) # type: ignore - channel_id won't be None here + raise ChannelNotFound(channel_id) # type: ignore # channel_id won't be None here return discord.PartialMessage(channel=channel, id=message_id) @@ -407,7 +406,7 @@ class MessageConverter(IDConverter[discord.Message]): except discord.NotFound: raise MessageNotFound(argument) except discord.Forbidden: - raise ChannelNotReadable(channel) # type: ignore - type-checker thinks channel could be a DMChannel at this point + raise ChannelNotReadable(channel) # type: ignore # type-checker thinks channel could be a DMChannel at this point class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): @@ -462,8 +461,6 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): @staticmethod def _resolve_thread(ctx: Context[BotT], argument: str, attribute: str, type: Type[TT]) -> TT: - bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) result = None guild = ctx.guild @@ -563,25 +560,6 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) -class StoreChannelConverter(IDConverter[discord.StoreChannel]): - """Converts to a :class:`~discord.StoreChannel`. - - All lookups are via the local guild. If in a DM context, then the lookup - is done by the global cache. - - The lookup strategy is as follows (in order): - - 1. Lookup by ID. - 2. Lookup by mention. - 3. Lookup by name. - - .. versionadded:: 1.7 - """ - - async def convert(self, ctx: Context[BotT], argument: str) -> discord.StoreChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) - - class ThreadConverter(IDConverter[discord.Thread]): """Coverts to a :class:`~discord.Thread`. @@ -1118,7 +1096,6 @@ CONVERTER_MAPPING: Dict[type, Any] = { discord.Emoji: EmojiConverter, discord.PartialEmoji: PartialEmojiConverter, discord.CategoryChannel: CategoryChannelConverter, - discord.StoreChannel: StoreChannelConverter, discord.Thread: ThreadConverter, discord.abc.GuildChannel: GuildChannelConverter, discord.GuildSticker: GuildStickerConverter, diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 875ef145f..7aadf0506 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -84,7 +84,7 @@ class Cooldown: Attributes ----------- - rate: :class:`int` + rate: :class:`float` The total number of tokens available per :attr:`per` seconds. per: :class:`float` The length of the cooldown period in seconds. @@ -179,7 +179,7 @@ class Cooldown: self._tokens = self.rate self._last = 0.0 - def copy(self) -> Cooldown: + def copy(self) -> Self: """Creates a copy of this cooldown. Returns diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 865629edc..026a84bac 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -519,9 +519,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): other.checks = self.checks.copy() if self._buckets.valid and not other._buckets.valid: other._buckets = self._buckets.copy() - if self._max_concurrency != other._max_concurrency: - # _max_concurrency won't be None at this point - other._max_concurrency = self._max_concurrency.copy() # type: ignore + if self._max_concurrency and self._max_concurrency != other._max_concurrency: + other._max_concurrency = self._max_concurrency.copy() try: other.on_error = self.on_error @@ -605,10 +604,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): previous = view.index if consume_rest_is_special: - argument = view.read_rest().strip() + ctx.current_argument = argument = view.read_rest().strip() else: try: - argument = view.get_quoted_word() + ctx.current_argument = argument = view.get_quoted_word() except ArgumentParsingError as exc: if self._is_typing_optional(param.annotation): view.index = previous @@ -631,7 +630,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): view.skip_ws() try: - argument = view.get_quoted_word() + ctx.current_argument = argument = view.get_quoted_word() value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous @@ -647,7 +646,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): view = ctx.view previous = view.index try: - argument = view.get_quoted_word() + ctx.current_argument = argument = view.get_quoted_word() value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous @@ -664,6 +663,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ return self.params.copy() + @property + def cooldown(self) -> Optional[Cooldown]: + """Optional[:class:`.Cooldown`]: The cooldown of a command when invoked + or ``None`` if the command doesn't have a registered cooldown. + + .. versionadded:: 2.0 + """ + return self._buckets._cooldown + @property def full_parent_name(self) -> str: """:class:`str`: Retrieves the fully qualified parent command name. @@ -746,7 +754,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # kwarg only param denotes "consume rest" semantics if self.rest_is_raw: converter = get_converter(param) - argument = view.read_rest() + ctx.current_argument = argument = view.read_rest() kwargs[name] = await run_converters(ctx, converter, argument, param) else: kwargs[name] = await self.transform(ctx, param) @@ -1622,7 +1630,7 @@ def command( [ Union[ Callable[Concatenate[ContextT, P], Coro[Any]], - Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore # CogT is used here to allow covariance ] ], CommandT, @@ -1691,7 +1699,7 @@ def group( ) -> Callable[ [ Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore - CogT is used here to allow covariance + Callable[Concatenate[CogT, ContextT, P], Coro[Any]], # type: ignore # CogT is used here to allow covariance Callable[Concatenate[ContextT, P], Coro[Any]], ] ], @@ -2294,8 +2302,8 @@ def dynamic_cooldown( This differs from :func:`.cooldown` in that it takes a function that accepts a single parameter of type :class:`.discord.Message` and must - return a :class:`.Cooldown` or ``None``. If ``None`` is returned then - that cooldown is effectively bypassed. + return a :class:`.Cooldown` or ``None``. + If ``None`` is returned then that cooldown is effectively bypassed. A cooldown allows a command to only be used a specific amount of times in a specific time frame. These cooldowns can be based diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 2b0567b5b..d9e46d8bb 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Optional, Any, TYPE_CHECKING, List, Callable, Type, Tuple, Union +from typing import Optional, Any, TYPE_CHECKING, List, Callable, Tuple, Union from discord.errors import ClientException, DiscordException diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index c5c924dbf..01458ab28 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -24,8 +24,8 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from contextvars import ContextVar import itertools +import copy import functools import re @@ -33,12 +33,12 @@ from typing import ( TYPE_CHECKING, Optional, Generator, - Generic, List, TypeVar, Callable, Any, Dict, + Tuple, Iterable, Sequence, Mapping, @@ -50,14 +50,21 @@ from .core import Group, Command, get_signature_parameters from .errors import CommandError if TYPE_CHECKING: + from typing_extensions import Self import inspect import discord.abc - from ._types import Coro from .bot import BotBase - from .cog import Cog from .context import Context + from .cog import Cog + + from ._types import ( + Check, + ContextT, + BotT, + _Bot, + ) __all__ = ( 'Paginator', @@ -66,11 +73,7 @@ __all__ = ( 'MinimalHelpCommand', ) -T = TypeVar('T') - -ContextT = TypeVar('ContextT', bound='Context') FuncT = TypeVar('FuncT', bound=Callable[..., Any]) -HelpCommandCommand = Command[Optional['Cog'], ... if TYPE_CHECKING else Any, Any] MISSING: Any = discord.utils.MISSING @@ -216,12 +219,92 @@ def _not_overridden(f: FuncT) -> FuncT: return f -_context: ContextVar[Optional[Context]] = ContextVar('context', default=None) +class _HelpCommandImpl(Command): + def __init__(self, inject: HelpCommand, *args: Any, **kwargs: Any) -> None: + super().__init__(inject.command_callback, *args, **kwargs) + self._original: HelpCommand = inject + self._injected: HelpCommand = inject + self.params: Dict[str, inspect.Parameter] = get_signature_parameters( + inject.command_callback, globals(), skip_parameters=1 + ) + + async def prepare(self, ctx: Context[Any]) -> None: + self._injected = injected = self._original.copy() + injected.context = ctx + self.callback = injected.command_callback + self.params = get_signature_parameters(injected.command_callback, globals(), skip_parameters=1) + + on_error = injected.on_help_command_error + if not hasattr(on_error, '__help_command_not_overridden__'): + if self.cog is not None: + self.on_error = self._on_error_cog_implementation + else: + self.on_error = on_error + + await super().prepare(ctx) + + async def _parse_arguments(self, ctx: Context[BotT]) -> None: + # Make the parser think we don't have a cog so it doesn't + # inject the parameter into `ctx.args`. + original_cog = self.cog + self.cog = None + try: + await super()._parse_arguments(ctx) + finally: + self.cog = original_cog + + async def _on_error_cog_implementation(self, _, ctx: Context[BotT], error: CommandError) -> None: + await self._injected.on_help_command_error(ctx, error) + + def _inject_into_cog(self, cog: Cog) -> None: + # Warning: hacky + # Make the cog think that get_commands returns this command + # as well if we inject it without modifying __cog_commands__ + # since that's used for the injection and ejection of cogs. + def wrapped_get_commands( + *, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands + ) -> List[Command[Any, ..., Any]]: + ret = _original() + ret.append(self) + return ret -class HelpCommand(HelpCommandCommand, Generic[ContextT]): + # Ditto here + def wrapped_walk_commands( + *, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands + ): + yield from _original() + yield self + + functools.update_wrapper(wrapped_get_commands, cog.get_commands) + functools.update_wrapper(wrapped_walk_commands, cog.walk_commands) + cog.get_commands = wrapped_get_commands + cog.walk_commands = wrapped_walk_commands + self.cog = cog + + def _eject_cog(self) -> None: + if self.cog is None: + return + + # revert back into their original methods + cog = self.cog + cog.get_commands = cog.get_commands.__wrapped__ + cog.walk_commands = cog.walk_commands.__wrapped__ + self.cog = None + + +class HelpCommand: r"""The base implementation for help command formatting. + .. note:: + + Internally instances of this class are deep copied every time + the command itself is invoked to prevent a race condition + mentioned in :issue:`2123`. + + This means that relying on the state of this class to be + the same between command invocations would not work as expected. + Attributes ------------ context: Optional[:class:`Context`] @@ -253,67 +336,88 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) - def __init__( - self, - *, - show_hidden: bool = False, - verify_checks: bool = True, - command_attrs: Dict[str, Any] = MISSING, - ) -> None: - self.show_hidden: bool = show_hidden - self.verify_checks: bool = verify_checks + if TYPE_CHECKING: + __original_kwargs__: Dict[str, Any] + __original_args__: Tuple[Any, ...] + + def __new__(cls, *args: Any, **kwargs: Any) -> Self: + # To prevent race conditions of a single instance while also allowing + # for settings to be passed the original arguments passed must be assigned + # to allow for easier copies (which will be made when the help command is actually called) + # see issue 2123 + self = super().__new__(cls) + + # Shallow copies cannot be used in this case since it is not unusual to pass + # instances that need state, e.g. Paginator or what have you into the function + # The keys can be safely copied as-is since they're 99.99% certain of being + # string keys + deepcopy = copy.deepcopy + self.__original_kwargs__ = {k: deepcopy(v) for k, v in kwargs.items()} + self.__original_args__ = deepcopy(args) + return self + + def __init__(self, **options: Any) -> None: + self.show_hidden: bool = options.pop('show_hidden', False) + self.verify_checks: bool = options.pop('verify_checks', True) self.command_attrs: Dict[str, Any] - self.command_attrs = attrs = command_attrs if command_attrs is not MISSING else {} + self.command_attrs = attrs = options.pop('command_attrs', {}) attrs.setdefault('name', 'help') attrs.setdefault('help', 'Shows this message') - self._cog: Optional[Cog] = None - super().__init__(self._set_context, **attrs) - self.params: Dict[str, inspect.Parameter] = get_signature_parameters( - self.command_callback, globals(), skip_parameters=1 - ) - if not hasattr(self.on_help_command_error, '__help_command_not_overridden__'): - self.on_error = self.on_help_command_error + self.context: Context[_Bot] = MISSING + self._command_impl = _HelpCommandImpl(self, **self.command_attrs) - async def __call__(self, context: ContextT, /, *args: Any, **kwargs: Any) -> Any: - return await self._set_context(context, *args, **kwargs) - - async def _set_context(self, context: ContextT, *args: Any, **kwargs: Any) -> Any: - _context.set(context) - return await self.command_callback(context, *args, **kwargs) - - @property - def context(self) -> ContextT: - ctx = _context.get() - if ctx is None: - raise AttributeError('context attribute cannot be accessed in non command-invocation contexts.') - return ctx # type: ignore + def copy(self) -> Self: + obj = self.__class__(*self.__original_args__, **self.__original_kwargs__) + obj._command_impl = self._command_impl + return obj def _add_to_bot(self, bot: BotBase) -> None: - bot.add_command(self) # type: ignore + command = _HelpCommandImpl(self, **self.command_attrs) + bot.add_command(command) + self._command_impl = command def _remove_from_bot(self, bot: BotBase) -> None: - bot.remove_command(self.name) - self._eject_cog() + bot.remove_command(self._command_impl.name) + self._command_impl._eject_cog() - async def _call_without_cog(self, callback: Callable[[ContextT], Coro[T]], ctx: ContextT) -> T: - cog = self._cog - self.cog = None - try: - return await callback(ctx) - finally: - self.cog = cog + def add_check(self, func: Check[ContextT], /) -> None: + """ + Adds a check to the help command. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 + + ``func`` parameter is now positional-only. + + Parameters + ---------- + func + The function that will be used as a check. + """ + + self._command_impl.add_check(func) - async def _parse_arguments(self, ctx: ContextT) -> None: - return await self._call_without_cog(super()._parse_arguments, ctx) + def remove_check(self, func: Check[ContextT], /) -> None: + """ + Removes a check from the help command. + + This function is idempotent and will not raise an exception if + the function is not in the command's checks. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 - async def call_before_hooks(self, ctx: ContextT, /) -> None: - return await self._call_without_cog(super().call_before_hooks, ctx) + ``func`` parameter is now positional-only. - async def call_after_hooks(self, ctx: ContextT, /) -> None: - return await self._call_without_cog(super().call_after_hooks, ctx) + Parameters + ---------- + func + The function to remove from the checks. + """ - async def can_run(self, ctx: ContextT, /) -> bool: - return await self._call_without_cog(super().can_run, ctx) + self._command_impl.remove_check(func) def get_bot_mapping(self) -> Dict[Optional[Cog], List[Command[Any, ..., Any]]]: """Retrieves the bot mapping passed to :meth:`send_bot_help`.""" @@ -337,7 +441,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): Optional[:class:`str`] The command name that triggered this invocation. """ - command_name = self.name + command_name = self._command_impl.name ctx = self.context if ctx is MISSING or ctx.command is None or ctx.command.qualified_name != command_name: return command_name @@ -361,7 +465,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): The signature for the command. """ - parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore - the parent will be a Group + parent: Optional[Group[Any, ..., Any]] = command.parent # type: ignore # the parent will be a Group entries = [] while parent is not None: if not parent.signature or parent.invoke_without_command: @@ -402,61 +506,31 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): return self.MENTION_PATTERN.sub(replace, string) - async def _on_error_cog_implementation(self, _, ctx: ContextT, error: CommandError) -> None: - await self.on_help_command_error(ctx, error) - - def _inject_into_cog(self, cog: Cog) -> None: - # Warning: hacky - - # Make the cog think that get_commands returns this command - # as well if we inject it without modifying __cog_commands__ - # since that's used for the injection and ejection of cogs. - def wrapped_get_commands( - *, _original: Callable[[], List[Command[Any, ..., Any]]] = cog.get_commands - ) -> List[Command[Any, ..., Any]]: - ret = _original() - ret.append(self) - return ret - - # Ditto here - def wrapped_walk_commands( - *, _original: Callable[[], Generator[Command[Any, ..., Any], None, None]] = cog.walk_commands - ): - yield from _original() - yield self - - functools.update_wrapper(wrapped_get_commands, cog.get_commands) - functools.update_wrapper(wrapped_walk_commands, cog.walk_commands) - cog.get_commands = wrapped_get_commands - cog.walk_commands = wrapped_walk_commands - if not hasattr(self.on_help_command_error, '__help_command_not_overridden__'): - self.on_error = self._on_error_cog_implementation - self._cog = cog + @property + def cog(self) -> Optional[Cog]: + """A property for retrieving or setting the cog for the help command. - def _eject_cog(self) -> None: - if self._cog is None: - return + When a cog is set for the help command, it is as-if the help command + belongs to that cog. All cog special methods will apply to the help + command and it will be automatically unset on unload. - # revert back into their original methods - if not hasattr(self.on_help_command_error, '__help_command_not_overridden__'): - self.on_error = self.on_help_command_error - cog = self._cog - cog.get_commands = cog.get_commands.__wrapped__ - cog.walk_commands = cog.walk_commands.__wrapped__ - self._cog = None + To unbind the cog from the help command, you can set it to ``None``. - @property - def cog(self) -> Optional[Cog]: - return self._cog + Returns + -------- + Optional[:class:`Cog`] + The cog that is currently set for the help command. + """ + return self._command_impl.cog @cog.setter def cog(self, cog: Optional[Cog]) -> None: # Remove whatever cog is currently valid, if any - self._eject_cog() + self._command_impl._eject_cog() # If a new cog is set then inject it. if cog is not None: - self._inject_into_cog(cog) + self._command_impl._inject_into_cog(cog) def command_not_found(self, string: str, /) -> str: """|maybecoro| @@ -561,7 +635,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): if self.verify_checks is False: # if we do not need to verify the checks then we can just # run it straight through normally without using await. - return sorted(iterator, key=key) if sort else list(iterator) # type: ignore - the key shouldn't be None + return sorted(iterator, key=key) if sort else list(iterator) # type: ignore # the key shouldn't be None if self.verify_checks is None and not self.context.guild: # if verify_checks is None and we're in a DM, don't verify @@ -648,7 +722,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): await destination.send(error) @_not_overridden - async def on_help_command_error(self, ctx: ContextT, error: CommandError, /) -> None: + async def on_help_command_error(self, ctx: Context[BotT], error: CommandError, /) -> None: """|coro| The help command's error handler, as specified by :ref:`ext_commands_error_handler`. @@ -811,7 +885,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): """ return None - async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None, /) -> None: + async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str] = None, /) -> None: """|coro| A low level method that can be used to prepare the help command @@ -839,7 +913,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): """ pass - async def command_callback(self, ctx: ContextT, /, *, command: Optional[str] = None) -> Any: + async def command_callback(self, ctx: Context[BotT], /, *, command: Optional[str] = None) -> None: """|coro| The actual implementation of the help command. @@ -889,7 +963,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): for key in keys[1:]: try: - found = cmd.all_commands.get(key) + found = cmd.all_commands.get(key) # type: ignore except AttributeError: string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) @@ -905,7 +979,7 @@ class HelpCommand(HelpCommandCommand, Generic[ContextT]): return await self.send_command_help(cmd) -class DefaultHelpCommand(HelpCommand[ContextT]): +class DefaultHelpCommand(HelpCommand): """The implementation of the default help command. This inherits from :class:`HelpCommand`. @@ -1059,7 +1133,7 @@ class DefaultHelpCommand(HelpCommand[ContextT]): else: return ctx.channel - async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None, /) -> None: + async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str], /) -> None: self.paginator.clear() await super().prepare_help_command(ctx, command) @@ -1127,7 +1201,7 @@ class DefaultHelpCommand(HelpCommand[ContextT]): await self.send_pages() -class MinimalHelpCommand(HelpCommand[ContextT]): +class MinimalHelpCommand(HelpCommand): """An implementation of a help command with minimal output. This inherits from :class:`HelpCommand`. @@ -1319,7 +1393,7 @@ class MinimalHelpCommand(HelpCommand[ContextT]): else: return ctx.channel - async def prepare_help_command(self, ctx: ContextT, command: Optional[str] = None, /) -> None: + async def prepare_help_command(self, ctx: Context[BotT], command: Optional[str], /) -> None: self.paginator.clear() await super().prepare_help_command(ctx, command) diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index 96d086811..e287221eb 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -181,7 +181,7 @@ class StringView: next_char = self.get() valid_eof = not next_char or next_char.isspace() if not valid_eof: - raise InvalidEndOfQuotedStringError(next_char) # type: ignore - this will always be a string + raise InvalidEndOfQuotedStringError(next_char) # type: ignore # this will always be a string # we're quoted so it's okay return ''.join(result) diff --git a/discord/file.py b/discord/file.py index 425a1fa67..d2d207362 100644 --- a/discord/file.py +++ b/discord/file.py @@ -23,11 +23,13 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import os import io +from .utils import MISSING + # fmt: off __all__ = ( 'File', @@ -35,6 +37,14 @@ __all__ = ( # fmt: on +def _strip_spoiler(filename: str) -> Tuple[str, bool]: + stripped = filename + while stripped.startswith('SPOILER_'): + stripped = stripped[8:] # len('SPOILER_') + spoiler = stripped != filename + return stripped, spoiler + + class File: r"""A parameter object used for :meth:`abc.Messageable.send` for sending file objects. @@ -67,21 +77,22 @@ class File: .. versionadded:: 2.0 spoiler: :class:`bool` - Whether the attachment is a spoiler. + Whether the attachment is a spoiler. If left unspecified, the :attr:`~File.filename` is used + to determine if the file is a spoiler. description: Optional[:class:`str`] The file description to display, currently only supported for images. .. versionadded:: 2.0 """ - __slots__ = ('fp', 'filename', 'spoiler', 'description', '_original_pos', '_owner', '_closer') + __slots__ = ('fp', '_filename', 'spoiler', 'description', '_original_pos', '_owner', '_closer') def __init__( self, fp: Union[str, bytes, os.PathLike[Any], io.BufferedIOBase], filename: Optional[str] = None, *, - spoiler: bool = False, + spoiler: bool = MISSING, description: Optional[str] = None, ): if isinstance(fp, io.IOBase): @@ -103,18 +114,29 @@ class File: if filename is None: if isinstance(fp, str): - _, self.filename = os.path.split(fp) + _, filename = os.path.split(fp) else: - self.filename = getattr(fp, 'name', None) - else: - self.filename: Optional[str] = filename + filename = getattr(fp, 'name', 'untitled') - if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'): - self.filename = 'SPOILER_' + self.filename + self._filename, filename_spoiler = _strip_spoiler(filename) # type: ignore # the above getattr doesn't narrow the type + if spoiler is MISSING: + spoiler = filename_spoiler - self.spoiler: bool = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_')) + self.spoiler: bool = spoiler self.description: Optional[str] = description + @property + def filename(self) -> str: + """:class:`str`: The filename to display when uploading to Discord. + If this is not given then it defaults to ``fp.name`` or if ``fp`` is + a string then the ``filename`` will default to the string given. + """ + return 'SPOILER_' + self._filename if self.spoiler else self._filename + + @filename.setter + def filename(self, value: str) -> None: + self._filename, self.spoiler = _strip_spoiler(value) + def reset(self, *, seek: Union[int, bool] = True) -> None: # The `seek` parameter is needed because # the retry-loop is iterated over multiple times diff --git a/discord/gateway.py b/discord/gateway.py index 1d39175a4..368433265 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -888,7 +888,7 @@ class DiscordVoiceWebSocket: self._close_code: Optional[int] = None self.secret_key: Optional[str] = None if hook: - self._hook = hook # type: ignore - type-checker doesn't like overriding methods + self._hook = hook # type: ignore # type-checker doesn't like overriding methods async def _hook(self, *args: Any) -> None: pass diff --git a/discord/guild.py b/discord/guild.py index cf89ea910..640b47b26 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -114,7 +114,7 @@ if TYPE_CHECKING: ) from .types.voice import GuildVoiceState from .permissions import Permissions - from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel + from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel from .template import Template from .webhook import Webhook from .state import ConnectionState @@ -125,7 +125,6 @@ if TYPE_CHECKING: NewsChannel as NewsChannelPayload, VoiceChannel as VoiceChannelPayload, CategoryChannel as CategoryChannelPayload, - StoreChannel as StoreChannelPayload, StageChannel as StageChannelPayload, ) from .types.integration import IntegrationType @@ -133,7 +132,7 @@ if TYPE_CHECKING: from .types.widget import EditWidgetSettings VocalGuildChannel = Union[VoiceChannel, StageChannel] - GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel, StoreChannel] + GuildChannel = Union[VocalGuildChannel, TextChannel, CategoryChannel] ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]] @@ -410,7 +409,7 @@ class Guild(Hashable): ) -> Tuple[Optional[Member], VoiceState, VoiceState]: cache_flags = self._state.member_cache_flags user_id = int(data['user_id']) - channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore - this will always be a voice channel + channel: Optional[VocalGuildChannel] = self.get_channel(channel_id) # type: ignore # this will always be a voice channel try: # Check if we should remove the voice state from cache if channel is None: @@ -454,7 +453,7 @@ class Guild(Hashable): def _from_data(self, guild: Union[GuildPayload, GuildPreviewPayload]) -> None: try: - self._member_count: int = guild['member_count'] # type: ignore - Handled below + self._member_count: int = guild['member_count'] # type: ignore # Handled below except KeyError: pass @@ -611,7 +610,7 @@ class Guild(Hashable): This is essentially used to get the member version of yourself. """ self_id = self._state.self_id - return self.get_member(self_id) # type: ignore - The self member is *always* cached + return self.get_member(self_id) # type: ignore # The self member is *always* cached @utils.cached_slot_property('_cs_joined') def joined(self) -> bool: @@ -1142,17 +1141,6 @@ class Guild(Hashable): ) -> Coroutine[Any, Any, NewsChannelPayload]: ... - @overload - def _create_channel( - self, - name: str, - channel_type: Literal[ChannelType.store], - overwrites: Mapping[Union[Role, Member], PermissionOverwrite] = ..., - category: Optional[Snowflake] = ..., - **options: Any, - ) -> Coroutine[Any, Any, StoreChannelPayload]: - ... - @overload def _create_channel( self, @@ -2007,35 +1995,117 @@ class Guild(Hashable): if ch_type in (ChannelType.group, ChannelType.private): raise InvalidData('Channel ID resolved to a private channel') - guild_id = int(data['guild_id']) # type: ignore - channel won't be a private channel + guild_id = int(data['guild_id']) # type: ignore # channel won't be a private channel if self.id != guild_id: raise InvalidData('Guild ID resolved to a different guild') - channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore - channel won't be a private channel + channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore # channel won't be a private channel return channel - async def bans(self) -> List[BanEntry]: - """|coro| - - Retrieves all the users that are banned from the guild as a :class:`list` of :class:`BanEntry`. + async def bans( + self, + *, + limit: Optional[int] = 1000, + before: Snowflake = MISSING, + after: Snowflake = MISSING, + ) -> AsyncIterator[BanEntry]: + """Retrieves an :term:`asynchronous iterator` of the users that are banned from the guild as a :class:`BanEntry`. You must have the :attr:`~Permissions.ban_members` permission to get this information. + .. versionchanged:: 2.0 + Due to a breaking change in Discord's API, this now returns a paginated iterator instead of a list. + + Examples + --------- + + Usage :: + + async for entry in guild.bans(limit=150): + print(entry.user, entry.reason) + + Flattening into a list :: + + bans = [entry async for entry in guild.bans(limit=2000)] + # bans is now a list of BanEntry... + + All parameters are optional. + + Parameters + ----------- + limit: Optional[:class:`int`] + The number of bans to retrieve. If ``None``, it retrieves every ban in + the guild. Note, however, that this would make it a slow operation. + Defaults to ``1000``. + before: :class:`.abc.Snowflake` + Retrieves bans before this user. + after: :class:`.abc.Snowflake` + Retrieve bans after this user. + Raises ------- Forbidden You do not have proper permissions to get the information. HTTPException An error occurred while fetching the information. + TypeError + Both ``after`` and ``before`` were provided, as Discord does not + support this type of pagination. - Returns + Yields -------- - List[:class:`BanEntry`] - A list of :class:`BanEntry` objects. + :class:`BanEntry` + The ban entry of the banned user. """ - data = await self._state.http.get_bans(self.id) - return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data] + if before is not MISSING and after is not MISSING: + raise TypeError('bans pagination does not support both before and after') + + # This endpoint paginates in ascending order. + endpoint = self._state.http.get_bans + + async def _before_strategy(retrieve, before, limit): + before_id = before.id if before else None + data = await endpoint(self.id, limit=retrieve, before=before_id) + + if data: + if limit is not None: + limit -= len(data) + + before = Object(id=int(data[0]['user']['id'])) + + return data, before, limit + + async def _after_strategy(retrieve, after, limit): + after_id = after.id if after else None + data = await endpoint(self.id, limit=retrieve, after=after_id) + + if data: + if limit is not None: + limit -= len(data) + + after = Object(id=int(data[-1]['user']['id'])) + + return data, after, limit + + if before: + strategy, state = _before_strategy, before + else: + strategy, state = _after_strategy, after + + while True: + retrieve = min(1000 if limit is None else limit, 1000) + if retrieve < 1: + return + + data, state, limit = await strategy(retrieve, state, limit) + + # Terminate loop on next iteration; there's no data left after this + if len(data) < 1000: + limit = 0 + + for e in data: + yield BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) async def prune_members( self, @@ -3165,7 +3235,7 @@ class Guild(Hashable): payload['max_uses'] = 0 payload['max_age'] = 0 payload['uses'] = payload.get('uses', 0) - return Invite(state=self._state, data=payload, guild=self, channel=channel) # type: ignore - we're faking a payload here + return Invite(state=self._state, data=payload, guild=self, channel=channel) # type: ignore # We're faking a payload here async def audit_logs( self, @@ -3638,7 +3708,7 @@ class Guild(Hashable): limit = min(100, limit or 5) members = await self._state.query_members( - self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache # type: ignore - The two types are compatible + self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache # type: ignore # The two types are compatible ) if subscribe: ids: List[_Snowflake] = [str(m.id) for m in members] diff --git a/discord/guild_folder.py b/discord/guild_folder.py index 5d65d684d..81f386b2b 100644 --- a/discord/guild_folder.py +++ b/discord/guild_folder.py @@ -65,7 +65,7 @@ class GuildFolder: self.id: Snowflake = data['id'] self.name: str = data['name'] self._colour: int = data['color'] - self.guilds: List[Guild] = list(filter(None, map(self._get_guild, data['guild_ids']))) # type: ignore - Lying for better developer UX + self.guilds: List[Guild] = list(filter(None, map(self._get_guild, data['guild_ids']))) # type: ignore # Lying for better developer UX def _get_guild(self, id): return self._state._get_guild(int(id)) or Object(id=int(id)) diff --git a/discord/http.py b/discord/http.py index c01e942ff..8f43921fb 100644 --- a/discord/http.py +++ b/discord/http.py @@ -72,7 +72,7 @@ _log = logging.getLogger(__name__) if TYPE_CHECKING: from typing_extensions import Self - from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable + from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable, VoiceChannel from .handlers import CaptchaHandler from .threads import Thread from .file import File @@ -110,7 +110,7 @@ if TYPE_CHECKING: T = TypeVar('T') BE = TypeVar('BE', bound=BaseException) Response = Coroutine[Any, Any, T] - MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable] + MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable, VoiceChannel] async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any], str]: @@ -356,7 +356,7 @@ class HTTPClient: session = self.__session if session: try: - session.connector._close() # type: ignore - Handled below + session.connector._close() # type: ignore # Handled below except AttributeError: pass @@ -582,7 +582,7 @@ class HTTPClient: # Captcha handling except HTTPException as e: try: - captcha_key = data['captcha_key'] # type: ignore - Handled below + captcha_key = data['captcha_key'] # type: ignore # Handled below except (KeyError, TypeError): raise else: @@ -593,7 +593,7 @@ class HTTPClient: raise else: previous = payload or {} - previous['captcha_key'] = await captcha_handler.fetch_token(data, self.proxy, self.proxy_auth) # type: ignore - data is json here + previous['captcha_key'] = await captcha_handler.fetch_token(data, self.proxy, self.proxy_auth) # type: ignore # data is json here kwargs['headers']['Content-Type'] = 'application/json' kwargs['data'] = utils._to_json(previous) @@ -839,9 +839,9 @@ class HTTPClient: try: msg = data[0] except IndexError: - raise NotFound(_FakeResponse('Not Found', 404), 'message not found') # type: ignore - _FakeResponse is not a real response + raise NotFound(_FakeResponse('Not Found', 404), 'message not found') # type: ignore # _FakeResponse is not a real response if int(msg['id']) != message_id: - raise NotFound(_FakeResponse('Not Found', 404), 'message not found') # type: ignore - _FakeResponse is not a real Response + raise NotFound(_FakeResponse('Not Found', 404), 'message not found') # type: ignore # _FakeResponse is not a real Response return msg @@ -1347,8 +1347,22 @@ class HTTPClient: return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload) - def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: - return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id)) + def get_bans( + self, + guild_id: Snowflake, + limit: int, + before: Optional[Snowflake] = None, + after: Optional[Snowflake] = None, + ) -> Response[List[guild.Ban]]: + params: Dict[str, Any] = {} + if limit != 1000: + params['limit'] = limit + if before is not None: + params['before'] = before + if after is not None: + params['after'] = after + + return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id), params=params) def get_ban(self, user_id: Snowflake, guild_id: Snowflake) -> Response[guild.Ban]: return self.request(Route('GET', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id)) @@ -1467,7 +1481,7 @@ class HTTPClient: self, guild_id: Snowflake, sticker_id: Snowflake, - payload: sticker.EditGuildSticker, + payload: Dict[str, Any], reason: Optional[str], ) -> Response[sticker.GuildSticker]: return self.request( diff --git a/discord/interactions.py b/discord/interactions.py index b102f82b9..c456b4fa3 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -27,10 +27,11 @@ from __future__ import annotations from typing import Optional, TYPE_CHECKING, Union from .enums import InteractionType, try_enum +from .mixins import Hashable from .utils import cached_slot_property, find, MISSING if TYPE_CHECKING: - from .channel import DMChannel, GroupChannel, TextChannel + from .channel import DMChannel, GroupChannel, TextChannel, VoiceChannel from .guild import Guild from .message import Message from .modal import Modal @@ -40,7 +41,7 @@ if TYPE_CHECKING: from .types.user import User as UserPayload from .user import BaseUser, ClientUser - MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel] + MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel, VoiceChannel] # fmt: off __all__ = ( @@ -49,7 +50,7 @@ __all__ = ( # fmt: on -class Interaction: +class Interaction(Hashable): """Represents an interaction. .. versionadded:: 2.0 diff --git a/discord/invite.py b/discord/invite.py index 522709709..3ddb57c24 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -477,7 +477,7 @@ class Invite(Hashable): channel = state.get_channel(getattr(channel, 'id', None)) or channel if message is not None: - data['message'] = message # type: ignore - Not a real field + data['message'] = message # type: ignore # Not a real field return cls(state=state, data=data, guild=guild, channel=channel, welcome_screen=welcome_screen) # type: ignore @@ -564,7 +564,7 @@ class Invite(Hashable): """ self.scheduled_event_id = scheduled_event.id try: - self.scheduled_event = self.guild.get_scheduled_event(scheduled_event.id) # type: ignore - handled below + self.scheduled_event = self.guild.get_scheduled_event(scheduled_event.id) # type: ignore # handled below except AttributeError: self.scheduled_event = None diff --git a/discord/iterators.py b/discord/iterators.py index f10eb9db9..929272f18 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -246,7 +246,7 @@ class FakeCommandIterator: channel = await item._get_channel() # type: ignore item = None text = None - if not channel.recipient.bot: # type: ignore - Type checker cannot understand this + if not channel.recipient.bot: # type: ignore # Type checker cannot understand this raise TypeError('User is not a bot') return channel, text, item # type: ignore diff --git a/discord/member.py b/discord/member.py index 04df7c456..3c9e2c55f 100644 --- a/discord/member.py +++ b/discord/member.py @@ -28,7 +28,7 @@ import datetime import inspect import itertools from operator import attrgetter -from typing import Any, Callable, Collection, Coroutine, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Union, Type +from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, Type import discord.abc @@ -331,7 +331,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): default_avatar: Asset avatar: Optional[Asset] dm_channel: Optional[DMChannel] - create_dm: Callable[[], Coroutine[Any, Any, DMChannel]] + create_dm: Callable[[], Awaitable[DMChannel]] mutual_guilds: List[Guild] public_flags: PublicUserFlags banner: Optional[Asset] @@ -668,8 +668,11 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): channel permission overwrites. For 100% accurate permission calculation, please use :meth:`abc.GuildChannel.permissions_for`. - This does take into consideration guild ownership and the - administrator implication. + This does take into consideration guild ownership, the + administrator implication, and whether the member is timed out. + + .. versionchanged:: 2.0 + Member timeouts are taken into consideration. """ if self.guild.owner_id == self.id: @@ -682,6 +685,9 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): if base.administrator: return Permissions.all() + if self.is_timed_out(): + base.value &= Permissions._timeout_mask() + return base @property @@ -767,7 +773,7 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): Can now pass ``None`` to ``voice_channel`` to kick a member from voice. .. versionchanged:: 2.0 - The newly member is now optionally returned, if applicable. + The newly updated member is now optionally returned, if applicable. Parameters ----------- @@ -936,7 +942,9 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): """ await self.edit(voice_channel=channel, reason=reason) - async def timeout(self, when: Union[datetime.timedelta, datetime.datetime], /, *, reason: Optional[str] = None) -> None: + async def timeout( + self, until: Optional[Union[datetime.timedelta, datetime.datetime]], /, *, reason: Optional[str] = None + ) -> None: """|coro| Applies a time out to a member until the specified date time or for the @@ -949,26 +957,28 @@ class Member(discord.abc.Messageable, discord.abc.Connectable, _UserTag): Parameters ----------- - when: Union[:class:`datetime.timedelta`, :class:`datetime.datetime`] + until: Optional[Union[:class:`datetime.timedelta`, :class:`datetime.datetime`]] If this is a :class:`datetime.timedelta` then it represents the amount of time the member should be timed out for. If this is a :class:`datetime.datetime` - then it's when the member's timeout should expire. Note that the API only allows - for timeouts up to 28 days. + then it's when the member's timeout should expire. If ``None`` is passed then the + timeout is removed. Note that the API only allows for timeouts up to 28 days. reason: Optional[:class:`str`] The reason for doing this action. Shows up on the audit log. Raises ------- TypeError - The ``when`` parameter was the wrong type of the datetime was not timezone-aware. + The ``until`` parameter was the wrong type of the datetime was not timezone-aware. """ - if isinstance(when, datetime.timedelta): - timed_out_until = utils.utcnow() + when - elif isinstance(when, datetime.datetime): - timed_out_until = when + if until is None: + timed_out_until = None + elif isinstance(until, datetime.timedelta): + timed_out_until = utils.utcnow() + until + elif isinstance(until, datetime.datetime): + timed_out_until = until else: - raise TypeError(f'expected datetime.datetime or datetime.timedelta not {when.__class__!r}') + raise TypeError(f'expected None, datetime.datetime, or datetime.timedelta not {until.__class__!r}') await self.edit(timed_out_until=timed_out_until, reason=reason) diff --git a/discord/message.py b/discord/message.py index 909ffa279..e1341b9d3 100644 --- a/discord/message.py +++ b/discord/message.py @@ -40,7 +40,6 @@ from typing import ( Callable, Tuple, ClassVar, - Optional, Type, overload, ) @@ -78,6 +77,8 @@ if TYPE_CHECKING: MessageActivity as MessageActivityPayload, ) + from .types.interactions import MessageInteraction as MessageInteractionPayload + from .types.components import Component as ComponentPayload from .types.threads import ThreadArchiveDuration from .types.member import ( @@ -88,7 +89,7 @@ if TYPE_CHECKING: from .types.embed import Embed as EmbedPayload from .types.gateway import MessageReactionRemoveEvent, MessageUpdateEvent from .abc import Snowflake - from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel + from .abc import GuildChannel, MessageableChannel from .components import Component from .state import ConnectionState from .channel import TextChannel @@ -343,7 +344,7 @@ class Attachment(Hashable): """ data = await self.read(use_cached=use_cached) - return File(io.BytesIO(data), filename=self.filename, spoiler=spoiler) + return File(io.BytesIO(data), filename=self.filename, description=self.description, spoiler=spoiler) def to_dict(self) -> AttachmentPayload: result: AttachmentPayload = { @@ -509,7 +510,7 @@ class MessageReference: result['guild_id'] = self.guild_id if self.fail_if_not_exists is not None: result['fail_if_not_exists'] = self.fail_if_not_exists - return result # type: ignore - Type checker doesn't understand these are the same + return result # type: ignore # Type checker doesn't understand these are the same to_message_reference_dict = to_dict @@ -573,13 +574,16 @@ class PartialMessage(Hashable): def __init__(self, *, channel: MessageableChannel, id: int) -> None: if not isinstance(channel, PartialMessageable) and channel.type not in ( ChannelType.text, + ChannelType.voice, ChannelType.news, ChannelType.private, ChannelType.news_thread, ChannelType.public_thread, ChannelType.private_thread, ): - raise TypeError(f'Expected PartialMessageable, TextChannel, DMChannel or Thread not {type(channel)!r}') + raise TypeError( + f'expected PartialMessageable, TextChannel, VoiceChannel, DMChannel or Thread not {type(channel)!r}' + ) self.channel: MessageableChannel = channel self._state: ConnectionState = channel._state @@ -1241,7 +1245,7 @@ class Message(PartialMessage, Hashable): .. versionadded:: 2.0 interaction: Optional[:class:`Interaction`] - The interaction the message is replying to, if applicable. + The interaction that this message is a response to. .. versionadded:: 2.0 """ @@ -1295,7 +1299,8 @@ class Message(PartialMessage, Hashable): channel: MessageableChannel, data: MessagePayload, ) -> None: - super().__init__(channel=channel, id=int(data['id'])) + self.channel: MessageableChannel = channel + self.id: int = int(data['id']) self._state: ConnectionState = state self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id') self.application_id: Optional[int] = utils._get_as_snowflake(data, 'application_id') @@ -1390,17 +1395,17 @@ class Message(PartialMessage, Hashable): reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) if reaction is None: - # already removed? + # Already removed? raise ValueError('Emoji already removed?') - # if reaction isn't in the list, we crash. This means discord + # If reaction isn't in the list, we crash; this means Discord # sent bad data, or we stored improperly reaction.count -= 1 if user_id == self._state.self_id: reaction.me = False if reaction.count == 0: - # this raises ValueError if something went wrong as well. + # This raises ValueError if something went wrong as well self.reactions.remove(reaction) return reaction @@ -1411,7 +1416,7 @@ class Message(PartialMessage, Hashable): if str(reaction.emoji) == to_check: break else: - # didn't find anything so just return + # Didn't find anything so just return return del self.reactions[index] @@ -1430,7 +1435,7 @@ class Message(PartialMessage, Hashable): else: handler(self, value) - # clear the cached properties + # Clear the cached properties for attr in self._CACHED_SLOTS: try: delattr(self, attr) @@ -1484,9 +1489,9 @@ class Message(PartialMessage, Hashable): # The gateway now gives us full Member objects sometimes with the following keys # deaf, mute, joined_at, roles # For the sake of performance I'm going to assume that the only - # field that needs *updating* would be the joined_at field. + # field that needs *updating* would be the joined_at field # If there is no Member object (for some strange reason), then we can upgrade - # ourselves to a more "partial" member object. + # ourselves to a more "partial" member object author = self.author try: # Update member reference @@ -1540,8 +1545,8 @@ class Message(PartialMessage, Hashable): def _handle_components(self, components: List[ComponentPayload]): self.components = [_component_factory(d, self) for d in components] - def _handle_interaction(self, interaction: Dict[str, Any]): - self.interaction = Interaction._from_message(self, **interaction) + def _handle_interaction(self, data: MessageInteractionPayload): + self.interaction = Interaction._from_message(self, **data) def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None: self.guild = new_guild diff --git a/discord/opus.py b/discord/opus.py index 33641554e..9eda3f0b3 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -217,7 +217,8 @@ def _load_default() -> bool: _filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll') _lib = libopus_loader(_filename) else: - _lib = libopus_loader(ctypes.util.find_library('opus')) + # This is handled in the exception case + _lib = libopus_loader(ctypes.util.find_library('opus')) # type: ignore except Exception: _lib = None diff --git a/discord/permissions.py b/discord/permissions.py index 312108820..6cef95270 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -151,6 +151,13 @@ class Permissions(BaseFlags): """ return cls(0b11111111111111111111111111111111111111111) + @classmethod + def _timeout_mask(cls) -> int: + p = cls.all() + p.view_channel = False + p.read_message_history = False + return ~p.value + @classmethod def all_channel(cls) -> Self: """A :class:`Permissions` with all channel-specific permissions set to @@ -691,7 +698,7 @@ class PermissionOverwrite: send_messages_in_threads: Optional[bool] external_stickers: Optional[bool] use_external_stickers: Optional[bool] - start_embedded_activities: Optional[bool] + use_embedded_activities: Optional[bool] moderate_members: Optional[bool] timeout_members: Optional[bool] diff --git a/discord/player.py b/discord/player.py index 6320844d8..e223f6e62 100644 --- a/discord/player.py +++ b/discord/player.py @@ -163,7 +163,7 @@ class FFmpegAudio(AudioSource): kwargs.update(subprocess_kwargs) self._process: subprocess.Popen = self._spawn_process(args, **kwargs) - self._stdout: IO[bytes] = self._process.stdout # type: ignore - process stdout is explicitly set + self._stdout: IO[bytes] = self._process.stdout # type: ignore # process stdout is explicitly set self._stdin: Optional[IO[bytes]] = None self._pipe_thread: Optional[threading.Thread] = None diff --git a/discord/profile.py b/discord/profile.py index 5c392c9dc..8f74f5073 100644 --- a/discord/profile.py +++ b/discord/profile.py @@ -107,7 +107,7 @@ class Profile: application = data.get('application', {}) install_params = application.get('install_params', {}) self.application_id = app_id = utils._get_as_snowflake(application, 'id') - self.install_url = application.get('custom_install_url') if not install_params else utils.oauth_url(app_id, permissions=Permissions(int(install_params.get('permissions', 0))), scopes=install_params.get('scopes', utils.MISSING)) # type: ignore - app_id is always present here + self.install_url = application.get('custom_install_url') if not install_params else utils.oauth_url(app_id, permissions=Permissions(int(install_params.get('permissions', 0))), scopes=install_params.get('scopes', utils.MISSING)) # type: ignore # app_id is always present here def _parse_mutual_guilds(self, mutual_guilds) -> Optional[List[Guild]]: if mutual_guilds is None: @@ -118,7 +118,7 @@ class Profile: def get_guild(guild): return state._get_guild(int(guild['id'])) or Object(id=int(guild['id'])) - return list(filter(None, map(get_guild, mutual_guilds))) # type: ignore - Lying for better developer UX + return list(filter(None, map(get_guild, mutual_guilds))) # type: ignore # Lying for better developer UX def _parse_mutual_friends(self, mutual_friends) -> Optional[List[User]]: if mutual_friends is None: diff --git a/discord/reaction.py b/discord/reaction.py index ec53ff644..3653dff73 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations -from typing import Any, TYPE_CHECKING, AsyncIterator, Union, Optional +from typing import TYPE_CHECKING, AsyncIterator, Union, Optional from .user import User from .object import Object diff --git a/discord/role.py b/discord/role.py index 03d63375d..be996be13 100644 --- a/discord/role.py +++ b/discord/role.py @@ -38,8 +38,6 @@ __all__ = ( ) if TYPE_CHECKING: - from typing_extensions import Self - import datetime from .types.role import ( Role as RolePayload, diff --git a/discord/state.py b/discord/state.py index 819aa0a78..322572108 100644 --- a/discord/state.py +++ b/discord/state.py @@ -439,7 +439,7 @@ class ConnectionState: self._status: Optional[str] = status if cache_flags._empty: - self.store_user = self.create_user # type: ignore + self.store_user = self.create_user # type: ignore # Purposeful reassignment self.parsers: Dict[str, Callable[[Any], None]] self.parsers = parsers = {} @@ -566,7 +566,7 @@ class ConnectionState: def _update_references(self, ws: DiscordWebSocket) -> None: for vc in self.voice_clients: - vc.main_ws = ws # type: ignore - Silencing the unknown attribute (ok at runtime). + vc.main_ws = ws # type: ignore # Silencing the unknown attribute (ok at runtime). def _add_interaction(self, interaction: Interaction) -> None: self._interactions[interaction.id] = interaction @@ -832,13 +832,13 @@ class ConnectionState: data.get('merged_members', []), extra_data['merged_presences'].get('guilds', []), ): - guild_data['settings'] = utils.find( # type: ignore - This key does not actually exist in the payload + guild_data['settings'] = utils.find( # type: ignore # This key does not actually exist in the payload lambda i: i['guild_id'] == guild_data['id'], guild_settings, ) or {'guild_id': guild_data['id']} for presence in merged_presences: - presence['user'] = {'id': presence['user_id']} # type: ignore - :( + presence['user'] = {'id': presence['user_id']} # type: ignore # :( voice_states = guild_data.setdefault('voice_states', []) voice_states.extend(guild_extra.get('voice_states', [])) @@ -923,8 +923,7 @@ class ConnectionState: if message.call is not None: self._call_message_cache[message.id] = message - # We ensure that the channel is either a TextChannel or Thread - if channel and channel.__class__ in (TextChannel, Thread): + if channel: channel.last_message_id = message.id # type: ignore def parse_message_delete(self, data: gw.MessageDeleteEvent) -> None: @@ -968,7 +967,7 @@ class ConnectionState: def parse_message_reaction_add(self, data: gw.MessageReactionAddEvent) -> None: emoji = data['emoji'] emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) + emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) # type: ignore raw = RawReactionActionEvent(data, emoji, 'REACTION_ADD') member_data = data.get('member') @@ -1182,7 +1181,7 @@ class ConnectionState: channel = guild.get_channel(channel_id) if channel is not None: old_channel = copy.copy(channel) - channel._update(guild, data) # type: ignore - the data payload varies based on the channel type. + channel._update(guild, data) # type: ignore # the data payload varies based on the channel type. self.dispatch('guild_channel_update', old_channel, channel) else: _log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) @@ -1442,7 +1441,12 @@ class ConnectionState: self.dispatch('member_update', old_member, member) else: if self.member_cache_flags.other or user_id == self.self_id or guild.chunked: - member = Member(data=data, guild=guild, state=self) # type: ignore - The data is close enough + member = Member(data=data, guild=guild, state=self) # type: ignore # The data is close enough + # Force an update on the inner user if necessary + user_update = member._update_inner_user(user) + if user_update: + self.dispatch('user_update', user_update[0], user_update[1]) + guild._add_member(member) _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s.', user_id) @@ -1721,7 +1725,7 @@ class ConnectionState: delay: Union[int, float] = MISSING, ) -> Union[Optional[List[Member]], asyncio.Future[Optional[List[Member]]]]: if not guild.me: - await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore - self_id is always present here + await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore # self_id is always present here if not force_scraping and any( { @@ -1747,7 +1751,7 @@ class ConnectionState: if wait: return await request.wait() - return request.get_future() # type: ignore - Honestly, I'm confused too + return request.get_future() # type: ignore # Honestly, I'm confused too @overload async def chunk_guild( @@ -1769,7 +1773,7 @@ class ConnectionState: channels: List[abcSnowflake] = MISSING, ) -> Union[asyncio.Future[Optional[List[Member]]], Optional[List[Member]]]: if not guild.me: - await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore - self_id is always present here + await guild.query_members(user_ids=[self.self_id], cache=True) # type: ignore # self_id is always present here request = self._scrape_requests.get(guild.id) if request is None: @@ -2165,7 +2169,7 @@ class ConnectionState: def parse_relationship_add(self, data) -> None: key = int(data['id']) - old = self.user.get_relationship(key) # type: ignore - self.user is always present here + old = self.user.get_relationship(key) # type: ignore # self.user is always present here new = Relationship(state=self, data=data) self._relationships[key] = new if old is not None: @@ -2184,7 +2188,7 @@ class ConnectionState: def parse_interaction_create(self, data) -> None: type, name, channel = self._interaction_cache.pop(data['nonce'], (0, None, None)) - i = Interaction._from_self(channel, type=type, user=self.user, name=name, **data) # type: ignore - self.user is always present here + i = Interaction._from_self(channel, type=type, user=self.user, name=name, **data) # type: ignore # self.user is always present here self._interactions[i.id] = i self.dispatch('interaction', i) @@ -2192,7 +2196,7 @@ class ConnectionState: id = int(data['id']) i = self._interactions.get(id, None) if i is None: - i = Interaction(id, nonce=data['nonce'], user=self.user) # type: ignore - self.user is always present here + i = Interaction(id, nonce=data['nonce'], user=self.user) # type: ignore # self.user is always present here i.successful = True self.dispatch('interaction_finish', i) @@ -2200,7 +2204,7 @@ class ConnectionState: id = int(data['id']) i = self._interactions.pop(id, None) if i is None: - i = Interaction(id, nonce=data['nonce'], user=self.user) # type: ignore - self.user is always present here + i = Interaction(id, nonce=data['nonce'], user=self.user) # type: ignore # self.user is always present here i.successful = False self.dispatch('interaction_finish', i) @@ -2258,7 +2262,5 @@ class ConnectionState: if channel is not None: return channel - def create_message( - self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload - ) -> Message: + def create_message(self, *, channel: MessageableChannel, data: MessagePayload) -> Message: return Message(state=self, channel=channel, data=data) diff --git a/discord/sticker.py b/discord/sticker.py index 3face3e90..516cdb8fe 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -51,7 +51,6 @@ if TYPE_CHECKING: Sticker as StickerPayload, StandardSticker as StandardStickerPayload, GuildSticker as GuildStickerPayload, - EditGuildSticker, ) @@ -122,7 +121,7 @@ class StickerPack(Hashable): @property def banner(self) -> Optional[Asset]: """:class:`Asset`: The banner asset of the sticker pack.""" - return self._banner and Asset._from_sticker_banner(self._state, self._banner) # type: ignore - type-checker thinks _banner could be Literal[0] + return self._banner and Asset._from_sticker_banner(self._state, self._banner) # type: ignore def __repr__(self) -> str: return f'' @@ -491,7 +490,7 @@ class GuildSticker(Sticker): :class:`GuildSticker` The newly modified sticker. """ - payload: EditGuildSticker = {} + payload = {} if name is not MISSING: payload['name'] = name diff --git a/discord/threads.py b/discord/threads.py index 89ca26e9d..d2f0e7614 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -26,12 +26,11 @@ from __future__ import annotations from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING from datetime import datetime -import time import asyncio import copy from .mixins import Hashable -from .abc import Messageable +from .abc import Messageable, _purge_helper from .enums import ChannelType, try_enum from .errors import ClientException, InvalidData from .utils import MISSING, parse_time, snowflake_time, _get_as_snowflake @@ -384,7 +383,7 @@ class Thread(Messageable, Hashable): raise ClientException('Parent channel not found') return parent.permissions_for(obj) - async def delete_messages(self, messages: Iterable[Snowflake], /) -> None: + async def delete_messages(self, messages: Iterable[Snowflake], /, *, reason: Optional[str] = None) -> None: """|coro| Deletes a list of messages. This is similar to :meth:`Message.delete` @@ -402,6 +401,8 @@ class Thread(Messageable, Hashable): ----------- messages: Iterable[:class:`abc.Snowflake`] An iterable of messages denoting which ones to bulk delete. + reason: Optional[:class:`str`] + The reason for deleting the messages. Shows up on the audit log. Raises ------ @@ -416,7 +417,7 @@ class Thread(Messageable, Hashable): if len(messages) == 0: return # Do nothing - await self._state._delete_messages(self.id, messages) + await self._state._delete_messages(self.id, messages, reason=reason) async def purge( self, @@ -427,6 +428,7 @@ class Thread(Messageable, Hashable): after: Optional[SnowflakeTime] = None, around: Optional[SnowflakeTime] = None, oldest_first: Optional[bool] = False, + reason: Optional[str] = None, ) -> List[Message]: """|coro| @@ -464,6 +466,8 @@ class Thread(Messageable, Hashable): Same as ``around`` in :meth:`history`. oldest_first: Optional[:class:`bool`] Same as ``oldest_first`` in :meth:`history`. + reason: Optional[:class:`str`] + The reason for purging the messages. Shows up on the audit log. Raises ------- @@ -477,32 +481,16 @@ class Thread(Messageable, Hashable): List[:class:`.Message`] The list of messages that were deleted. """ - if check is MISSING: - check = lambda m: True - - state = self._state - channel_id = self.id - iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) - ret: List[Message] = [] - count = 0 - - async for message in iterator: - if count == 50: - to_delete = ret[-50:] - await state._delete_messages(channel_id, to_delete) - count = 0 - - if not check(message): - continue - - count += 1 - ret.append(message) - - # Some messages remaining to poll - to_delete = ret[-count:] - await state._delete_messages(channel_id, to_delete) - - return ret + return await _purge_helper( + self, + limit=limit, + check=check, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + reason=reason, + ) async def edit( self, diff --git a/discord/types/activity.py b/discord/types/activity.py index 24e8382ce..5902cce8a 100644 --- a/discord/types/activity.py +++ b/discord/types/activity.py @@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import List, Literal, Optional, TypedDict +from typing_extensions import NotRequired from .user import User from .snowflake import Snowflake @@ -69,13 +70,10 @@ class ActivitySecrets(TypedDict, total=False): match: str -class _ActivityEmojiOptional(TypedDict, total=False): - id: Snowflake - animated: bool - - -class ActivityEmoji(_ActivityEmojiOptional): +class ActivityEmoji(TypedDict): name: str + id: NotRequired[Snowflake] + animated: NotRequired[bool] class ActivityButton(TypedDict): @@ -83,16 +81,13 @@ class ActivityButton(TypedDict): url: str -class _SendableActivityOptional(TypedDict, total=False): - url: Optional[str] - - ActivityType = Literal[0, 1, 2, 4, 5] -class SendableActivity(_SendableActivityOptional): +class SendableActivity(TypedDict): name: str type: ActivityType + url: NotRequired[Optional[str]] class _BaseActivity(SendableActivity): diff --git a/discord/types/appinfo.py b/discord/types/appinfo.py index e5e76cc46..01ea16dda 100644 --- a/discord/types/appinfo.py +++ b/discord/types/appinfo.py @@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import TypedDict, List, Optional +from typing_extensions import NotRequired from .user import User from .team import Team @@ -38,39 +39,34 @@ class BaseAppInfo(TypedDict): icon: Optional[str] summary: str description: str - - -class _AppInfoOptional(TypedDict, total=False): - team: Team - guild_id: Snowflake - primary_sku_id: Snowflake - slug: str - terms_of_service_url: str - privacy_policy_url: str - hook: bool - max_participants: int - - -class _PartialAppInfoOptional(TypedDict, total=False): + cover_image: Optional[str] + flags: NotRequired[int] rpc_origins: List[str] - cover_image: str - hook: bool - terms_of_service_url: str - privacy_policy_url: str - max_participants: int - flags: int - -class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo): - pass - -class AppInfo(PartialAppInfo, _AppInfoOptional): +class AppInfo(BaseAppInfo): owner: User - integration_public: bool - integration_require_code_grant: bool - secret: str + bot_public: NotRequired[bool] + bot_require_code_grant: NotRequired[bool] + integration_public: NotRequired[bool] + integration_require_code_grant: NotRequired[bool] + team: NotRequired[Team] + guild_id: NotRequired[Snowflake] + primary_sku_id: NotRequired[Snowflake] + slug: NotRequired[str] + terms_of_service_url: NotRequired[str] + privacy_policy_url: NotRequired[str] + hook: NotRequired[bool] + max_participants: NotRequired[int] + interactions_endpoint_url: NotRequired[str] verification_state: int store_application_state: int rpc_application_state: int interactions_endpoint_url: str + + +class PartialAppInfo(BaseAppInfo, total=False): + hook: bool + terms_of_service_url: str + privacy_policy_url: str + max_participants: int diff --git a/discord/types/audit_log.py b/discord/types/audit_log.py index a041a3056..c2997596b 100644 --- a/discord/types/audit_log.py +++ b/discord/types/audit_log.py @@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import List, Literal, Optional, TypedDict, Union +from typing_extensions import NotRequired from .webhook import Webhook from .guild import MFALevel, VerificationLevel, ExplicitContentFilterLevel, DefaultMessageNotificationLevel @@ -273,17 +274,14 @@ class AuditEntryInfo(TypedDict): role_name: str -class _AuditLogEntryOptional(TypedDict, total=False): - changes: List[AuditLogChange] - options: AuditEntryInfo - reason: str - - -class AuditLogEntry(_AuditLogEntryOptional): +class AuditLogEntry(TypedDict): target_id: Optional[str] user_id: Optional[Snowflake] id: Snowflake action_type: AuditLogEvent + changes: NotRequired[List[AuditLogChange]] + options: NotRequired[AuditEntryInfo] + reason: NotRequired[str] class AuditLog(TypedDict): diff --git a/discord/types/channel.py b/discord/types/channel.py index 101378949..3c02259f4 100644 --- a/discord/types/channel.py +++ b/discord/types/channel.py @@ -23,6 +23,8 @@ DEALINGS IN THE SOFTWARE. """ from typing import List, Literal, Optional, TypedDict, Union +from typing_extensions import NotRequired + from .user import PartialUser from .snowflake import Snowflake from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration, ThreadType @@ -59,7 +61,7 @@ class PartialChannel(_BaseChannel): type: ChannelType -class _TextChannelOptional(TypedDict, total=False): +class _BaseTextChannel(_BaseGuildChannel, total=False): topic: str last_message_id: Optional[Snowflake] last_pin_timestamp: str @@ -67,56 +69,38 @@ class _TextChannelOptional(TypedDict, total=False): default_auto_archive_duration: ThreadArchiveDuration -class TextChannel(_BaseGuildChannel, _TextChannelOptional): +class TextChannel(_BaseTextChannel): type: Literal[0] -class NewsChannel(_BaseGuildChannel, _TextChannelOptional): +class NewsChannel(_BaseTextChannel): type: Literal[5] VideoQualityMode = Literal[1, 2] -class _VoiceChannelOptional(TypedDict, total=False): - rtc_region: Optional[str] - video_quality_mode: VideoQualityMode - - -class VoiceChannel(_BaseGuildChannel, _VoiceChannelOptional): +class VoiceChannel(_BaseTextChannel): type: Literal[2] bitrate: int user_limit: int + rtc_region: NotRequired[Optional[str]] + video_quality_mode: NotRequired[VideoQualityMode] class CategoryChannel(_BaseGuildChannel): type: Literal[4] -class StoreChannel(_BaseGuildChannel): - type: Literal[6] - - -class _StageChannelOptional(TypedDict, total=False): - rtc_region: Optional[str] - topic: str - - -class StageChannel(_BaseGuildChannel, _StageChannelOptional): +class StageChannel(_BaseGuildChannel): type: Literal[13] bitrate: int user_limit: int + rtc_region: NotRequired[Optional[str]] + topic: NotRequired[str] -class _ThreadChannelOptional(TypedDict, total=False): - member: ThreadMember - owner_id: Snowflake - rate_limit_per_user: int - last_message_id: Optional[Snowflake] - last_pin_timestamp: str - - -class ThreadChannel(_BaseChannel, _ThreadChannelOptional): +class ThreadChannel(_BaseChannel): type: Literal[10, 11, 12] guild_id: Snowflake parent_id: Snowflake @@ -127,9 +111,14 @@ class ThreadChannel(_BaseChannel, _ThreadChannelOptional): message_count: int member_count: int thread_metadata: ThreadMetadata + member: NotRequired[ThreadMember] + owner_id: NotRequired[Snowflake] + rate_limit_per_user: NotRequired[int] + last_message_id: NotRequired[Optional[Snowflake]] + last_pin_timestamp: NotRequired[str] -GuildChannel = Union[TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StoreChannel, StageChannel, ThreadChannel] +GuildChannel = Union[TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StageChannel, ThreadChannel] class DMChannel(_BaseChannel): diff --git a/discord/types/command.py b/discord/types/command.py index 412a8f126..77700d7f6 100644 --- a/discord/types/command.py +++ b/discord/types/command.py @@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import List, Literal, TypedDict, Union +from typing_extensions import NotRequired, Required from .channel import ChannelType from .snowflake import Snowflake @@ -57,13 +58,10 @@ class _StringApplicationCommandOptionChoice(TypedDict): value: str -class _StringApplicationCommandOptionOptional(_BaseValueApplicationCommandOption, total=False): - choices: List[_StringApplicationCommandOptionChoice] - autocomplete: bool - - -class _StringApplicationCommandOption(_StringApplicationCommandOptionOptional): +class _StringApplicationCommandOption(_BaseApplicationCommandOption): type: Literal[3] + choices: NotRequired[List[_StringApplicationCommandOptionChoice]] + autocomplete: NotRequired[bool] class _IntegerApplicationCommandOptionChoice(TypedDict): @@ -71,27 +69,21 @@ class _IntegerApplicationCommandOptionChoice(TypedDict): value: int -class _IntegerApplicationCommandOptionOptional(_BaseValueApplicationCommandOption, total=False): +class _IntegerApplicationCommandOption(_BaseApplicationCommandOption, total=False): + type: Required[Literal[4]] min_value: int max_value: int choices: List[_IntegerApplicationCommandOptionChoice] autocomplete: bool -class _IntegerApplicationCommandOption(_IntegerApplicationCommandOptionOptional): - type: Literal[4] - - class _BooleanApplicationCommandOption(_BaseValueApplicationCommandOption): type: Literal[5] -class _ChannelApplicationCommandOptionChoiceOptional(_BaseApplicationCommandOption, total=False): - channel_types: List[ChannelType] - - -class _ChannelApplicationCommandOptionChoice(_ChannelApplicationCommandOptionChoiceOptional): +class _ChannelApplicationCommandOptionChoice(_BaseApplicationCommandOption): type: Literal[7] + channel_types: NotRequired[List[ChannelType]] class _NonChannelSnowflakeApplicationCommandOptionChoice(_BaseValueApplicationCommandOption): @@ -109,17 +101,14 @@ class _NumberApplicationCommandOptionChoice(TypedDict): value: float -class _NumberApplicationCommandOptionOptional(_BaseValueApplicationCommandOption, total=False): +class _NumberApplicationCommandOption(_BaseValueApplicationCommandOption, total=False): + type: Required[Literal[10]] min_value: float max_value: float choices: List[_NumberApplicationCommandOptionChoice] autocomplete: bool -class _NumberApplicationCommandOption(_NumberApplicationCommandOptionOptional): - type: Literal[10] - - _ValueApplicationCommandOption = Union[ _StringApplicationCommandOption, _IntegerApplicationCommandOption, @@ -148,7 +137,8 @@ class _BaseApplicationCommand(TypedDict): version: Snowflake -class _ChatInputApplicationCommandOptional(_BaseApplicationCommand, total=False): +class _ChatInputApplicationCommand(_BaseApplicationCommand, total=False): + description: Required[str] type: Literal[1] options: Union[ List[_ValueApplicationCommandOption], @@ -156,10 +146,6 @@ class _ChatInputApplicationCommandOptional(_BaseApplicationCommand, total=False) ] -class _ChatInputApplicationCommand(_ChatInputApplicationCommandOptional): - description: str - - class _BaseContextMenuApplicationCommand(_BaseApplicationCommand): description: Literal[""] diff --git a/discord/types/components.py b/discord/types/components.py index 92c41a2fb..9c197bebd 100644 --- a/discord/types/components.py +++ b/discord/types/components.py @@ -24,7 +24,9 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import List, Literal, Optional, TypedDict, Union +from typing import List, Literal, TypedDict, Union +from typing_extensions import NotRequired + from .emoji import PartialEmoji ComponentType = Literal[1, 2, 3, 4] @@ -37,56 +39,44 @@ class ActionRow(TypedDict): components: List[Component] -class _ButtonComponentOptional(TypedDict, total=False): - custom_id: str - url: str - disabled: bool - emoji: PartialEmoji - label: str - - -class ButtonComponent(_ButtonComponentOptional): +class ButtonComponent(TypedDict): type: Literal[2] style: ButtonStyle + custom_id: NotRequired[str] + url: NotRequired[str] + disabled: NotRequired[bool] + emoji: NotRequired[PartialEmoji] + label: NotRequired[str] -class _SelectMenuOptional(TypedDict, total=False): - placeholder: str - min_values: int - max_values: int - disabled: bool - - -class _SelectOptionsOptional(TypedDict, total=False): - description: str - emoji: PartialEmoji - - -class SelectOption(_SelectOptionsOptional): +class SelectOption(TypedDict): label: str value: str default: bool + description: NotRequired[str] + emoji: NotRequired[PartialEmoji] -class SelectMenu(_SelectMenuOptional): +class SelectMenu(TypedDict): type: Literal[3] custom_id: str options: List[SelectOption] + placeholder: NotRequired[str] + min_values: NotRequired[int] + max_values: NotRequired[int] + disabled: NotRequired[bool] -class _TextInputOptional(TypedDict, total=False): - placeholder: str - value: Optional[str] - required: bool - min_length: int - max_length: int - - -class TextInput(_TextInputOptional): +class TextInput(TypedDict): type: Literal[4] custom_id: str style: TextStyle label: str + placeholder: NotRequired[str] + value: NotRequired[str] + required: NotRequired[bool] + min_length: NotRequired[int] + max_length: NotRequired[int] Component = Union[ActionRow, ButtonComponent, SelectMenu, TextInput] diff --git a/discord/types/embed.py b/discord/types/embed.py index de38bd276..f2f1c5a9f 100644 --- a/discord/types/embed.py +++ b/discord/types/embed.py @@ -23,36 +23,28 @@ DEALINGS IN THE SOFTWARE. """ from typing import List, Literal, TypedDict +from typing_extensions import NotRequired, Required -class _EmbedFooterOptional(TypedDict, total=False): - icon_url: str - proxy_icon_url: str - - -class EmbedFooter(_EmbedFooterOptional): +class EmbedFooter(TypedDict): text: str + icon_url: NotRequired[str] + proxy_icon_url: NotRequired[str] -class _EmbedFieldOptional(TypedDict, total=False): - inline: bool - - -class EmbedField(_EmbedFieldOptional): +class EmbedField(TypedDict): name: str value: str + inline: NotRequired[bool] -class _EmbedThumbnailOptional(TypedDict, total=False): +class EmbedThumbnail(TypedDict, total=False): + url: Required[str] proxy_url: str height: int width: int -class EmbedThumbnail(_EmbedThumbnailOptional): - url: str - - class EmbedVideo(TypedDict, total=False): url: str proxy_url: str @@ -60,31 +52,25 @@ class EmbedVideo(TypedDict, total=False): width: int -class _EmbedImageOptional(TypedDict, total=False): +class EmbedImage(TypedDict, total=False): + url: Required[str] proxy_url: str height: int width: int -class EmbedImage(_EmbedImageOptional): - url: str - - class EmbedProvider(TypedDict, total=False): name: str url: str -class _EmbedAuthorOptional(TypedDict, total=False): +class EmbedAuthor(TypedDict, total=False): + name: Required[str] url: str icon_url: str proxy_icon_url: str -class EmbedAuthor(_EmbedAuthorOptional): - name: str - - EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link'] diff --git a/discord/types/gateway.py b/discord/types/gateway.py index 8353c9f63..69f0d0dd0 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from typing import List, Literal, Optional, TypedDict, Union +from typing_extensions import NotRequired, Required from .activity import PartialPresenceUpdate from .voice import GuildVoiceState @@ -92,68 +93,50 @@ ResumedEvent = Literal[None] MessageCreateEvent = Message -class _MessageDeleteEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class MessageDeleteEvent(_MessageDeleteEventOptional): +class MessageDeleteEvent(TypedDict): id: Snowflake channel_id: Snowflake + guild_id: NotRequired[Snowflake] -class _MessageDeleteBulkEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class MessageDeleteBulkEvent(_MessageDeleteBulkEventOptional): +class MessageDeleteBulkEvent(TypedDict): ids: List[Snowflake] channel_id: Snowflake + guild_id: NotRequired[Snowflake] class MessageUpdateEvent(Message): channel_id: Snowflake -class _MessageReactionAddEventOptional(TypedDict, total=False): - member: MemberWithUser - guild_id: Snowflake - - -class MessageReactionAddEvent(_MessageReactionAddEventOptional): +class MessageReactionAddEvent(TypedDict): user_id: Snowflake channel_id: Snowflake message_id: Snowflake emoji: PartialEmoji + member: NotRequired[MemberWithUser] + guild_id: NotRequired[Snowflake] -class _MessageReactionRemoveEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class MessageReactionRemoveEvent(_MessageReactionRemoveEventOptional): +class MessageReactionRemoveEvent(TypedDict): user_id: Snowflake channel_id: Snowflake message_id: Snowflake emoji: PartialEmoji + guild_id: NotRequired[Snowflake] -class _MessageReactionRemoveAllEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class MessageReactionRemoveAllEvent(_MessageReactionRemoveAllEventOptional): +class MessageReactionRemoveAllEvent(TypedDict): message_id: Snowflake channel_id: Snowflake + guild_id: NotRequired[Snowflake] -class _MessageReactionRemoveEmojiEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class MessageReactionRemoveEmojiEvent(_MessageReactionRemoveEmojiEventOptional): +class MessageReactionRemoveEmojiEvent(TypedDict): emoji: PartialEmoji message_id: Snowflake channel_id: Snowflake + guild_id: NotRequired[Snowflake] InteractionCreateEvent = Interaction @@ -162,15 +145,7 @@ InteractionCreateEvent = Interaction UserUpdateEvent = User -class _InviteCreateEventOptional(TypedDict, total=False): - guild_id: Snowflake - inviter: User - target_type: InviteTargetType - target_user: User - target_application: PartialAppInfo - - -class InviteCreateEvent(_InviteCreateEventOptional): +class InviteCreateEvent(TypedDict): channel_id: Snowflake code: str created_at: str @@ -178,15 +153,17 @@ class InviteCreateEvent(_InviteCreateEventOptional): max_uses: int temporary: bool uses: Literal[0] + guild_id: NotRequired[Snowflake] + inviter: NotRequired[User] + target_type: NotRequired[InviteTargetType] + target_user: NotRequired[User] + target_application: NotRequired[PartialAppInfo] -class _InviteDeleteEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class InviteDeleteEvent(_InviteDeleteEventOptional): +class InviteDeleteEvent(TypedDict): channel_id: Snowflake code: str + guild_id: NotRequired[Snowflake] class _ChannelEvent(TypedDict): @@ -197,24 +174,17 @@ class _ChannelEvent(TypedDict): ChannelCreateEvent = ChannelUpdateEvent = ChannelDeleteEvent = _ChannelEvent -class _ChannelPinsUpdateEventOptional(TypedDict, total=False): - guild_id: Snowflake - last_pin_timestamp: Optional[str] - - -class ChannelPinsUpdateEvent(_ChannelPinsUpdateEventOptional): +class ChannelPinsUpdateEvent(TypedDict): channel_id: Snowflake + guild_id: NotRequired[Snowflake] + last_pin_timestamp: NotRequired[Optional[str]] -class _ThreadCreateEventOptional(TypedDict, total=False): +class ThreadCreateEvent(Thread, total=False): newly_created: bool members: List[ThreadMember] -class ThreadCreateEvent(Thread, _ThreadCreateEventOptional): - ... - - ThreadUpdateEvent = Thread @@ -225,29 +195,23 @@ class ThreadDeleteEvent(TypedDict): type: ChannelType -class _ThreadListSyncEventOptional(TypedDict, total=False): - channel_ids: List[Snowflake] - - -class ThreadListSyncEvent(_ThreadListSyncEventOptional): +class ThreadListSyncEvent(TypedDict): guild_id: Snowflake threads: List[Thread] members: List[ThreadMember] + channel_ids: NotRequired[List[Snowflake]] class ThreadMemberUpdate(ThreadMember): guild_id: Snowflake -class _ThreadMembersUpdateOptional(TypedDict, total=False): - added_members: List[ThreadMember] - removed_member_ids: List[Snowflake] - - -class ThreadMembersUpdate(_ThreadMembersUpdateOptional): +class ThreadMembersUpdate(TypedDict): id: Snowflake guild_id: Snowflake member_count: int + added_members: NotRequired[List[ThreadMember]] + removed_member_ids: NotRequired[List[Snowflake]] class GuildMemberAddEvent(MemberWithUser): @@ -259,21 +223,18 @@ class GuildMemberRemoveEvent(TypedDict): user: User -class _GuildMemberUpdateEventOptional(TypedDict, total=False): - nick: str - premium_since: Optional[str] - deaf: bool - mute: bool - pending: bool - communication_disabled_until: str - - -class GuildMemberUpdateEvent(_GuildMemberUpdateEventOptional): +class GuildMemberUpdateEvent(TypedDict): guild_id: Snowflake roles: List[Snowflake] user: User avatar: Optional[str] joined_at: Optional[str] + nick: NotRequired[str] + premium_since: NotRequired[Optional[str]] + deaf: NotRequired[bool] + mute: NotRequired[bool] + pending: NotRequired[bool] + communication_disabled_until: NotRequired[str] class GuildEmojisUpdateEvent(TypedDict): @@ -311,24 +272,22 @@ class GuildRoleDeleteEvent(TypedDict): GuildRoleCreateEvent = GuildRoleUpdateEvent = _GuildRoleEvent -class _GuildMembersChunkEventOptional(TypedDict, total=False): - not_found: List[Snowflake] - presences: List[PresenceUpdateEvent] - nonce: str - - -class GuildMembersChunkEvent(_GuildMembersChunkEventOptional): +class GuildMembersChunkEvent(TypedDict): guild_id: Snowflake members: List[MemberWithUser] chunk_index: int chunk_count: int + not_found: NotRequired[List[Snowflake]] + presences: NotRequired[List[PresenceUpdateEvent]] + nonce: NotRequired[str] class GuildIntegrationsUpdateEvent(TypedDict): guild_id: Snowflake -class _IntegrationEventOptional(BaseIntegration, total=False): +class _IntegrationEvent(BaseIntegration, total=False): + guild_id: Required[Snowflake] role_id: Optional[Snowflake] enable_emoticons: bool subscriber_count: int @@ -336,20 +295,13 @@ class _IntegrationEventOptional(BaseIntegration, total=False): application: IntegrationApplication -class _IntegrationEvent(_IntegrationEventOptional): - guild_id: Snowflake - - IntegrationCreateEvent = IntegrationUpdateEvent = _IntegrationEvent -class _IntegrationDeleteEventOptional(TypedDict, total=False): - application_id: Snowflake - - -class IntegrationDeleteEvent(_IntegrationDeleteEventOptional): +class IntegrationDeleteEvent(TypedDict): id: Snowflake guild_id: Snowflake + application_id: NotRequired[Snowflake] class WebhooksUpdateEvent(TypedDict): @@ -379,12 +331,9 @@ class VoiceServerUpdateEvent(TypedDict): endpoint: Optional[str] -class _TypingStartEventOptional(TypedDict, total=False): - guild_id: Snowflake - member: MemberWithUser - - -class TypingStartEvent(_TypingStartEventOptional): +class TypingStartEvent(TypedDict): channel_id: Snowflake user_id: Snowflake timestamp: int + guild_id: NotRequired[Snowflake] + member: NotRequired[MemberWithUser] diff --git a/discord/types/guild.py b/discord/types/guild.py index e6e26cd12..b6de3d161 100644 --- a/discord/types/guild.py +++ b/discord/types/guild.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from typing import List, Literal, Optional, TypedDict +from typing_extensions import NotRequired from .scheduled_event import GuildScheduledEvent from .sticker import GuildSticker @@ -32,10 +33,9 @@ from .voice import GuildVoiceState from .welcome_screen import WelcomeScreen from .activity import PartialPresenceUpdate from .role import Role -from .member import Member, MemberWithUser +from .member import MemberWithUser from .emoji import Emoji from .user import User -from .sticker import GuildSticker from .threads import Thread @@ -44,31 +44,9 @@ class Ban(TypedDict): user: User -class _UnavailableGuildOptional(TypedDict, total=False): - unavailable: bool - - -class UnavailableGuild(_UnavailableGuildOptional): +class UnavailableGuild(TypedDict): id: Snowflake - - -class _GuildOptional(TypedDict, total=False): - icon_hash: Optional[str] - owner: bool - widget_enabled: bool - widget_channel_id: Optional[Snowflake] - joined_at: Optional[str] - large: bool - member_count: int - voice_states: List[GuildVoiceState] - members: List[MemberWithUser] - channels: List[GuildChannel] - presences: List[PartialPresenceUpdate] - threads: List[Thread] - max_presences: Optional[int] - max_members: int - premium_subscription_count: int - max_video_channel_users: int + unavailable: NotRequired[bool] DefaultMessageNotificationLevel = Literal[0, 1] @@ -99,7 +77,7 @@ class GuildPreview(_BaseGuildPreview, _GuildPreviewUnique): ... -class Guild(_BaseGuildPreview, _GuildOptional): +class Guild(_BaseGuildPreview): owner_id: Snowflake region: str afk_channel_id: Optional[Snowflake] @@ -122,6 +100,23 @@ class Guild(_BaseGuildPreview, _GuildOptional): stickers: List[GuildSticker] stage_instances: List[StageInstance] guild_scheduled_events: List[GuildScheduledEvent] + icon_hash: NotRequired[Optional[str]] + owner: NotRequired[bool] + permissions: NotRequired[str] + widget_enabled: NotRequired[bool] + widget_channel_id: NotRequired[Optional[Snowflake]] + joined_at: NotRequired[Optional[str]] + large: NotRequired[bool] + member_count: NotRequired[int] + voice_states: NotRequired[List[GuildVoiceState]] + members: NotRequired[List[MemberWithUser]] + channels: NotRequired[List[GuildChannel]] + presences: NotRequired[List[PartialPresenceUpdate]] + threads: NotRequired[List[Thread]] + max_presences: NotRequired[Optional[int]] + max_members: NotRequired[int] + premium_subscription_count: NotRequired[int] + max_video_channel_users: NotRequired[int] class InviteGuild(Guild, total=False): diff --git a/discord/types/integration.py b/discord/types/integration.py index f3ca0f3e5..cf73b7a90 100644 --- a/discord/types/integration.py +++ b/discord/types/integration.py @@ -25,20 +25,19 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import Literal, Optional, TypedDict, Union +from typing_extensions import NotRequired + from .snowflake import Snowflake from .user import User -class _IntegrationApplicationOptional(TypedDict, total=False): - bot: User - - -class IntegrationApplication(_IntegrationApplicationOptional): +class IntegrationApplication(TypedDict): id: Snowflake name: str icon: Optional[str] description: str summary: str + bot: NotRequired[User] class IntegrationAccount(TypedDict): diff --git a/discord/types/interactions.py b/discord/types/interactions.py index 8c097c0ed..293b9ac27 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, Union - +from typing_extensions import NotRequired from .channel import ChannelTypeWithoutThread, ThreadMetadata from .threads import ThreadType @@ -120,14 +120,11 @@ ApplicationCommandInteractionDataOption = Union[ ] -class _BaseApplicationCommandInteractionDataOptional(TypedDict, total=False): - resolved: ResolvedData - guild_id: Snowflake - - -class _BaseApplicationCommandInteractionData(_BaseApplicationCommandInteractionDataOptional): +class _BaseApplicationCommandInteractionData(TypedDict): id: Snowflake name: str + resolved: NotRequired[ResolvedData] + guild_id: NotRequired[Snowflake] class ChatInputApplicationCommandInteractionData(_BaseApplicationCommandInteractionData, total=False): @@ -199,18 +196,15 @@ InteractionData = Union[ ] -class _BaseInteractionOptional(TypedDict, total=False): - guild_id: Snowflake - channel_id: Snowflake - locale: str - guild_locale: str - - -class _BaseInteraction(_BaseInteractionOptional): +class _BaseInteraction(TypedDict): id: Snowflake application_id: Snowflake token: str version: Literal[1] + guild_id: NotRequired[Snowflake] + channel_id: NotRequired[Snowflake] + locale: NotRequired[str] + guild_locale: NotRequired[str] class PingInteraction(_BaseInteraction): @@ -240,3 +234,4 @@ class MessageInteraction(TypedDict): type: InteractionType name: str user: User + member: NotRequired[Member] diff --git a/discord/types/invite.py b/discord/types/invite.py index 6d7818e81..b53ca374c 100644 --- a/discord/types/invite.py +++ b/discord/types/invite.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import Literal, Optional, TypedDict, Union - +from typing_extensions import NotRequired from .scheduled_event import GuildScheduledEvent from .snowflake import Snowflake @@ -37,15 +37,6 @@ from .appinfo import PartialAppInfo InviteTargetType = Literal[1, 2] -class _InviteOptional(TypedDict, total=False): - guild: InviteGuild - inviter: PartialUser - target_user: PartialUser - target_type: InviteTargetType - target_application: PartialAppInfo - guild_scheduled_event: GuildScheduledEvent - - class _InviteMetadata(TypedDict, total=False): uses: int max_uses: int @@ -55,12 +46,9 @@ class _InviteMetadata(TypedDict, total=False): expires_at: Optional[str] -class _VanityInviteOptional(_InviteMetadata, total=False): - revoked: bool - - -class VanityInvite(_VanityInviteOptional): +class VanityInvite(_InviteMetadata): code: Optional[str] + revoked: NotRequired[bool] class IncompleteInvite(_InviteMetadata): @@ -68,23 +56,20 @@ class IncompleteInvite(_InviteMetadata): channel: PartialChannel -class Invite(IncompleteInvite, _InviteOptional): - ... +class Invite(IncompleteInvite, total=False): + guild: InviteGuild + inviter: PartialUser + target_user: PartialUser + target_type: InviteTargetType + target_application: PartialAppInfo + guild_scheduled_event: GuildScheduledEvent class InviteWithCounts(Invite, _GuildPreviewUnique): ... -class _GatewayInviteCreateOptional(TypedDict, total=False): - guild_id: Snowflake - inviter: PartialUser - target_type: InviteTargetType - target_user: PartialUser - target_application: PartialAppInfo - - -class GatewayInviteCreate(_GatewayInviteCreateOptional): +class GatewayInviteCreate(TypedDict): channel_id: Snowflake code: str created_at: str @@ -92,15 +77,17 @@ class GatewayInviteCreate(_GatewayInviteCreateOptional): max_uses: int temporary: bool uses: bool - - -class _GatewayInviteDeleteOptional(TypedDict, total=False): guild_id: Snowflake + inviter: NotRequired[PartialUser] + target_type: NotRequired[InviteTargetType] + target_user: NotRequired[PartialUser] + target_application: NotRequired[PartialAppInfo] -class GatewayInviteDelete(_GatewayInviteDeleteOptional): +class GatewayInviteDelete(TypedDict): channel_id: Snowflake code: str + guild_id: NotRequired[Snowflake] GatewayInvite = Union[GatewayInviteCreate, GatewayInviteDelete] diff --git a/discord/types/message.py b/discord/types/message.py index 9d29ad1ae..178285ed7 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -25,6 +25,8 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import List, Literal, Optional, TypedDict, Union +from typing_extensions import NotRequired + from .snowflake import Snowflake, SnowflakeList from .member import Member, UserWithMember from .user import User @@ -36,12 +38,9 @@ from .interactions import MessageInteraction from .sticker import StickerItem -class _PartialMessageOptional(TypedDict, total=False): - guild_id: Snowflake - - -class PartialMessage(_PartialMessageOptional): +class PartialMessage(TypedDict): channel_id: Snowflake + guild_id: NotRequired[Snowflake] class ChannelMention(TypedDict): @@ -57,21 +56,18 @@ class Reaction(TypedDict): emoji: PartialEmoji -class _AttachmentOptional(TypedDict, total=False): - height: Optional[int] - width: Optional[int] - description: str - content_type: str - spoiler: bool - ephemeral: bool - - -class Attachment(_AttachmentOptional): +class Attachment(TypedDict): id: Snowflake filename: str size: int url: str proxy_url: str + height: NotRequired[Optional[int]] + width: NotRequired[Optional[int]] + description: NotRequired[str] + content_type: NotRequired[str] + spoiler: NotRequired[bool] + ephemeral: NotRequired[bool] MessageActivityType = Literal[1, 2, 3, 5] @@ -82,15 +78,12 @@ class MessageActivity(TypedDict): party_id: str -class _MessageApplicationOptional(TypedDict, total=False): - cover_image: str - - -class MessageApplication(_MessageApplicationOptional): +class MessageApplication(TypedDict): id: Snowflake description: str icon: Optional[str] name: str + cover_image: NotRequired[str] class MessageReference(TypedDict, total=False): @@ -100,30 +93,11 @@ class MessageReference(TypedDict, total=False): fail_if_not_exists: bool -class _MessageOptional(TypedDict, total=False): - guild_id: Snowflake - member: Member - mention_channels: List[ChannelMention] - reactions: List[Reaction] - nonce: Union[int, str] - webhook_id: Snowflake - activity: MessageActivity - application: MessageApplication - application_id: Snowflake - message_reference: MessageReference - flags: int - sticker_items: List[StickerItem] - referenced_message: Optional[Message] - interaction: MessageInteraction - components: List[Component] - - -MessageType = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19, 20, 21, 22, 23] - - -class Message(_MessageOptional): +MessageType = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19, 20, 21] + + +class Message(PartialMessage): id: Snowflake - channel_id: Snowflake author: User content: str timestamp: str @@ -136,6 +110,20 @@ class Message(_MessageOptional): embeds: List[Embed] pinned: bool type: MessageType + member: NotRequired[Member] + mention_channels: NotRequired[List[ChannelMention]] + reactions: NotRequired[List[Reaction]] + nonce: NotRequired[Union[int, str]] + webhook_id: NotRequired[Snowflake] + activity: NotRequired[MessageActivity] + application: NotRequired[MessageApplication] + application_id: NotRequired[Snowflake] + message_reference: NotRequired[MessageReference] + flags: NotRequired[int] + sticker_items: NotRequired[List[StickerItem]] + referenced_message: NotRequired[Optional[Message]] + interaction: NotRequired[MessageInteraction] + components: NotRequired[List[Component]] AllowedMentionType = Literal['roles', 'users', 'everyone'] diff --git a/discord/types/role.py b/discord/types/role.py index b2f2ad2c4..17f66cbff 100644 --- a/discord/types/role.py +++ b/discord/types/role.py @@ -24,19 +24,13 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Optional, TypedDict -from .snowflake import Snowflake - +from typing import TypedDict, Optional +from typing_extensions import NotRequired -class _RoleOptional(TypedDict, total=False): - icon: Optional[str] - unicode_emoji: Optional[str] - tags: RoleTags - icon: Optional[str] - unicode_emoji: Optional[str] +from .snowflake import Snowflake -class Role(_RoleOptional): +class Role(TypedDict): id: Snowflake name: str color: int @@ -45,6 +39,9 @@ class Role(_RoleOptional): permissions: str managed: bool mentionable: bool + icon: NotRequired[Optional[str]] + unicode_emoji: NotRequired[Optional[str]] + tags: NotRequired[RoleTags] class RoleTags(TypedDict, total=False): diff --git a/discord/types/scheduled_event.py b/discord/types/scheduled_event.py index 83dd95dbe..91c15cb08 100644 --- a/discord/types/scheduled_event.py +++ b/discord/types/scheduled_event.py @@ -23,6 +23,7 @@ DEALINGS IN THE SOFTWARE. """ from typing import List, Literal, Optional, TypedDict, Union +from typing_extensions import NotRequired from .snowflake import Snowflake from .user import User @@ -33,15 +34,7 @@ EventStatus = Literal[1, 2, 3, 4] EntityType = Literal[1, 2, 3] -class _BaseGuildScheduledEventOptional(TypedDict, total=False): - creator_id: Optional[Snowflake] - description: Optional[str] - creator: User - user_count: int - image: Optional[str] - - -class _BaseGuildScheduledEvent(_BaseGuildScheduledEventOptional): +class _BaseGuildScheduledEvent(TypedDict): id: Snowflake guild_id: Snowflake entity_id: Optional[Snowflake] @@ -49,15 +42,17 @@ class _BaseGuildScheduledEvent(_BaseGuildScheduledEventOptional): scheduled_start_time: str privacy_level: PrivacyLevel status: EventStatus + creator_id: NotRequired[Optional[Snowflake]] + description: NotRequired[Optional[str]] + creator: NotRequired[User] + user_count: NotRequired[int] + image: NotRequired[Optional[str]] -class _VoiceChannelScheduledEventOptional(_BaseGuildScheduledEvent, total=False): - scheduled_end_time: Optional[str] - - -class _VoiceChannelScheduledEvent(_VoiceChannelScheduledEventOptional): +class _VoiceChannelScheduledEvent(_BaseGuildScheduledEvent): channel_id: Snowflake entity_metadata: Literal[None] + scheduled_end_time: NotRequired[Optional[str]] class StageInstanceScheduledEvent(_VoiceChannelScheduledEvent): diff --git a/discord/types/sticker.py b/discord/types/sticker.py index 2a0278542..7dcd0ccba 100644 --- a/discord/types/sticker.py +++ b/discord/types/sticker.py @@ -25,6 +25,8 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import List, Literal, TypedDict, Union, Optional +from typing_extensions import NotRequired + from .snowflake import Snowflake from .user import User @@ -51,14 +53,11 @@ class StandardSticker(BaseSticker): pack_id: Snowflake -class _GuildStickerOptional(TypedDict, total=False): - user: User - - -class GuildSticker(BaseSticker, _GuildStickerOptional): +class GuildSticker(BaseSticker): type: Literal[2] available: bool guild_id: Snowflake + user: NotRequired[User] Sticker = Union[StandardSticker, GuildSticker] @@ -74,19 +73,10 @@ class StickerPack(TypedDict): banner_asset_id: Optional[Snowflake] -class _CreateGuildStickerOptional(TypedDict, total=False): - description: str - - -class CreateGuildSticker(_CreateGuildStickerOptional): +class CreateGuildSticker(TypedDict): name: str tags: str - - -class EditGuildSticker(TypedDict, total=False): - name: str - tags: str - description: str + description: NotRequired[str] class ListPremiumStickerPacks(TypedDict): diff --git a/discord/types/threads.py b/discord/types/threads.py index 418167a97..44d3f1711 100644 --- a/discord/types/threads.py +++ b/discord/types/threads.py @@ -23,7 +23,9 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations + from typing import List, Literal, Optional, TypedDict +from typing_extensions import NotRequired from .snowflake import Snowflake @@ -38,26 +40,17 @@ class ThreadMember(TypedDict): flags: int -class _ThreadMetadataOptional(TypedDict, total=False): - archiver_id: Snowflake - locked: bool - invitable: bool - create_timestamp: str - - -class ThreadMetadata(_ThreadMetadataOptional): +class ThreadMetadata(TypedDict): archived: bool auto_archive_duration: ThreadArchiveDuration archive_timestamp: str + archiver_id: NotRequired[Snowflake] + locked: NotRequired[bool] + invitable: NotRequired[bool] + create_timestamp: NotRequired[str] -class _ThreadOptional(TypedDict, total=False): - member: ThreadMember - last_message_id: Optional[Snowflake] - last_pin_timestamp: Optional[Snowflake] - - -class Thread(_ThreadOptional): +class Thread(TypedDict): id: Snowflake guild_id: Snowflake parent_id: Snowflake @@ -69,6 +62,9 @@ class Thread(_ThreadOptional): rate_limit_per_user: int thread_metadata: ThreadMetadata member_ids_preview: List[Snowflake] + member: NotRequired[ThreadMember] + last_message_id: NotRequired[Optional[Snowflake]] + last_pin_timestamp: NotRequired[Optional[Snowflake]] class ThreadPaginationPayload(TypedDict): diff --git a/discord/types/voice.py b/discord/types/voice.py index f9d1df3d0..8f4e2e03e 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -23,6 +23,8 @@ DEALINGS IN THE SOFTWARE. """ from typing import Optional, TypedDict, List, Literal +from typing_extensions import NotRequired + from .snowflake import Snowflake from .member import MemberWithUser @@ -30,12 +32,7 @@ from .member import MemberWithUser SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305'] -class _PartialVoiceStateOptional(TypedDict, total=False): - member: MemberWithUser - self_stream: bool - - -class _VoiceState(_PartialVoiceStateOptional): +class _VoiceState(TypedDict): user_id: Snowflake session_id: str deaf: bool @@ -44,6 +41,8 @@ class _VoiceState(_PartialVoiceStateOptional): self_mute: bool self_video: bool suppress: bool + member: NotRequired[MemberWithUser] + self_stream: NotRequired[bool] class GuildVoiceState(_VoiceState): diff --git a/discord/types/webhook.py b/discord/types/webhook.py index c526d750b..dd5eea156 100644 --- a/discord/types/webhook.py +++ b/discord/types/webhook.py @@ -23,7 +23,10 @@ DEALINGS IN THE SOFTWARE. """ from __future__ import annotations + from typing import Literal, Optional, TypedDict +from typing_extensions import NotRequired + from .snowflake import Snowflake from .user import User from .channel import PartialChannel @@ -35,28 +38,22 @@ class SourceGuild(TypedDict): icon: str -class _WebhookOptional(TypedDict, total=False): - guild_id: Snowflake - user: User - token: str - - WebhookType = Literal[1, 2, 3] -class _FollowerWebhookOptional(TypedDict, total=False): - source_channel: PartialChannel - source_guild: SourceGuild - - -class FollowerWebhook(_FollowerWebhookOptional): +class FollowerWebhook(TypedDict): channel_id: Snowflake webhook_id: Snowflake + source_channel: NotRequired[PartialChannel] + source_guild: NotRequired[SourceGuild] -class PartialWebhook(_WebhookOptional): +class PartialWebhook(TypedDict): id: Snowflake type: WebhookType + guild_id: NotRequired[Snowflake] + user: NotRequired[User] + token: NotRequired[str] class _FullWebhook(TypedDict, total=False): diff --git a/discord/user.py b/discord/user.py index 74f592bdc..01cf8ea86 100644 --- a/discord/user.py +++ b/discord/user.py @@ -854,10 +854,10 @@ class User(BaseUser, discord.abc.Connectable, discord.abc.Messageable): return f'<{self.__class__.__name__} id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot} system={self.system}>' def _get_voice_client_key(self) -> Tuple[int, str]: - return self._state.self_id, 'self_id' # type: ignore - self_id is always set at this point + return self._state.self_id, 'self_id' # type: ignore # self_id is always set at this point def _get_voice_state_pair(self) -> Tuple[int, int]: - return self._state.self_id, self.dm_channel.id # type: ignore - self_id is always set at this point + return self._state.self_id, self.dm_channel.id # type: ignore # self_id is always set at this point async def _get_channel(self) -> DMChannel: ch = await self.create_dm() @@ -879,7 +879,7 @@ class User(BaseUser, discord.abc.Connectable, discord.abc.Messageable): @property def relationship(self) -> Optional[Relationship]: """Optional[:class:`Relationship`]: Returns the :class:`Relationship` with this user if applicable, ``None`` otherwise.""" - return self._state.user.get_relationship(self.id) # type: ignore - user is always present when logged in + return self._state.user.get_relationship(self.id) # type: ignore # user is always present when logged in @copy_doc(discord.abc.Connectable.connect) async def connect( diff --git a/discord/utils.py b/discord/utils.py index 282d74a1f..b75888240 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -151,10 +151,7 @@ if TYPE_CHECKING: P = ParamSpec('P') - MaybeCoroFunc = Union[ - Callable[P, Coroutine[Any, Any, 'T']], - Callable[P, 'T'], - ] + MaybeAwaitableFunc = Callable[P, 'MaybeAwaitable[T]'] _SnowflakeListBase = array.array[int] @@ -167,6 +164,7 @@ T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) _Iter = Union[Iterable[T], AsyncIterable[T]] Coro = Coroutine[Any, Any, T] +MaybeAwaitable = Union[T, Awaitable[T]] class CachedSlotProperty(Generic[T, T_co]): @@ -626,7 +624,7 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) -async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T: +async def maybe_coroutine(f: MaybeAwaitableFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T: value = f(*args, **kwargs) if _isawaitable(value): return await value @@ -1043,9 +1041,9 @@ def evaluate_annotation( if implicit_str and isinstance(tp, str): if tp in cache: return cache[tp] - evaluated = eval(tp, globals, locals) + evaluated = evaluate_annotation(eval(tp, globals, locals), globals, locals, cache) cache[tp] = evaluated - return evaluate_annotation(evaluated, globals, locals, cache) + return evaluated if hasattr(tp, '__args__'): implicit_str = True diff --git a/discord/voice_client.py b/discord/voice_client.py index bf46fb046..ee102a55e 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -440,7 +440,7 @@ class VoiceClient(VoiceProtocol): @property def user(self) -> ClientUser: """:class:`ClientUser`: The user connected to voice (i.e. ourselves).""" - return self._state.user # type: ignore - user can't be None after login + return self._state.user # type: ignore # Connection related @@ -465,9 +465,9 @@ class VoiceClient(VoiceProtocol): else: guild = self.guild if guild is not None: - self.channel = channel_id and guild.get_channel(int(channel_id)) # type: ignore - This won't be None + self.channel = channel_id and guild.get_channel(int(channel_id)) # type: ignore # This won't be None else: - self.channel = channel_id and self._state._get_private_channel(int(channel_id)) # type: ignore - This won't be None + self.channel = channel_id and self._state._get_private_channel(int(channel_id)) # type: ignore # This won't be None else: self._voice_state_complete.set() diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index 625b78930..57c03f56a 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -1072,7 +1072,7 @@ class Webhook(BaseWebhook): .. versionchanged:: 2.0 This function will now raise :exc:`ValueError` instead of - ``~InvalidArgument``. + ``InvalidArgument``. Parameters ------------ diff --git a/discord/widget.py b/discord/widget.py index 5ea0dcbd5..402cdab82 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, List, Optional, TYPE_CHECKING, Union +from typing import List, Optional, TYPE_CHECKING, Union from .utils import snowflake_time, _get_as_snowflake, resolve_invite from .user import BaseUser diff --git a/docs/api.rst b/docs/api.rst index ed50f8cdf..a45aa15a1 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1776,10 +1776,6 @@ of :class:`enum.Enum`. A guild news channel. - .. attribute:: store - - A guild store channel. - .. attribute:: stage_voice A guild stage voice channel. @@ -1908,9 +1904,9 @@ of :class:`enum.Enum`. The system message denoting that the author is replying to a message. .. versionadded:: 2.0 - .. attribute:: application_command + .. attribute:: chat_input_command - The system message denoting that an application (or "slash") command was executed. + The system message denoting that a slash command was executed. .. versionadded:: 2.0 .. attribute:: guild_invite_reminder @@ -1923,6 +1919,11 @@ of :class:`enum.Enum`. The system message denoting the message in the thread that is the one that started the thread's conversation topic. + .. versionadded:: 2.0 + .. attribute:: context_menu_command + + The system message denoting that a context menu command was executed. + .. versionadded:: 2.0 .. class:: UserFlags @@ -2191,7 +2192,7 @@ of :class:`enum.Enum`. - :attr:`~AuditLogDiff.afk_channel` - :attr:`~AuditLogDiff.system_channel` - :attr:`~AuditLogDiff.afk_timeout` - - :attr:`~AuditLogDiff.default_message_notifications` + - :attr:`~AuditLogDiff.default_notifications` - :attr:`~AuditLogDiff.explicit_content_filter` - :attr:`~AuditLogDiff.mfa_level` - :attr:`~AuditLogDiff.name` @@ -2787,6 +2788,7 @@ of :class:`enum.Enum`. - :attr:`~AuditLogDiff.privacy_level` - :attr:`~AuditLogDiff.status` - :attr:`~AuditLogDiff.entity_type` + - :attr:`~AuditLogDiff.cover_image` .. versionadded:: 2.0 @@ -2805,6 +2807,7 @@ of :class:`enum.Enum`. - :attr:`~AuditLogDiff.privacy_level` - :attr:`~AuditLogDiff.status` - :attr:`~AuditLogDiff.entity_type` + - :attr:`~AuditLogDiff.cover_image` .. versionadded:: 2.0 @@ -2823,6 +2826,7 @@ of :class:`enum.Enum`. - :attr:`~AuditLogDiff.privacy_level` - :attr:`~AuditLogDiff.status` - :attr:`~AuditLogDiff.entity_type` + - :attr:`~AuditLogDiff.cover_image` .. versionadded:: 2.0 @@ -3102,7 +3106,7 @@ of :class:`enum.Enum`. .. class:: Locale - Supported locales by Discord. Mainly used for application command localisation. + Supported locales by Discord. .. versionadded:: 2.0 @@ -3530,12 +3534,6 @@ AuditLogDiff :type: :class:`ContentFilter` - .. attribute:: default_message_notifications - - The guild's default message notification setting. - - :type: :class:`int` - .. attribute:: vanity_url_code The guild's vanity URL. @@ -3871,7 +3869,7 @@ AuditLogDiff See also :attr:`Guild.preferred_locale` - :type: :class:`str` + :type: :class:`Locale` .. attribute:: prune_delete_days @@ -3891,6 +3889,14 @@ AuditLogDiff :type: :class:`EntityType` + .. attribute:: cover_image + + The scheduled event's cover image. + + See also :attr:`ScheduledEvent.cover_image`. + + :type: :class:`Asset` + .. this is currently missing the following keys: reason and application_id I'm not sure how to about porting these @@ -4320,15 +4326,6 @@ ThreadMember .. autoclass:: ThreadMember() :members: -StoreChannel -~~~~~~~~~~~~~ - -.. attributetable:: StoreChannel - -.. autoclass:: StoreChannel() - :members: - :inherited-members: - VoiceChannel ~~~~~~~~~~~~~ diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst index 483a258bf..da98cd734 100644 --- a/docs/ext/commands/api.rst +++ b/docs/ext/commands/api.rst @@ -322,14 +322,6 @@ Checks .. _ext_commands_api_context: -Cooldown ---------- - -.. attributetable:: discord.ext.commands.Cooldown - -.. autoclass:: discord.ext.commands.Cooldown - :members: - Context -------- @@ -375,9 +367,6 @@ Converters .. autoclass:: discord.ext.commands.VoiceChannelConverter :members: -.. autoclass:: discord.ext.commands.StoreChannelConverter - :members: - .. autoclass:: discord.ext.commands.StageChannelConverter :members: diff --git a/docs/ext/commands/commands.rst b/docs/ext/commands/commands.rst index 4549d66e7..4fa97ae5c 100644 --- a/docs/ext/commands/commands.rst +++ b/docs/ext/commands/commands.rst @@ -390,7 +390,6 @@ A lot of discord models work out of the gate as a parameter: - :class:`TextChannel` - :class:`VoiceChannel` - :class:`StageChannel` (since v1.7) -- :class:`StoreChannel` (since v1.7) - :class:`CategoryChannel` - :class:`Invite` - :class:`Guild` (since v1.7) @@ -430,8 +429,6 @@ converter is given below: +--------------------------+-------------------------------------------------+ | :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` | +--------------------------+-------------------------------------------------+ -| :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` | -+--------------------------+-------------------------------------------------+ | :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` | +--------------------------+-------------------------------------------------+ | :class:`Invite` | :class:`~ext.commands.InviteConverter` | diff --git a/docs/ext/tasks/index.rst b/docs/ext/tasks/index.rst index 8f90a87dc..484d79620 100644 --- a/docs/ext/tasks/index.rst +++ b/docs/ext/tasks/index.rst @@ -108,7 +108,7 @@ Doing something during cancellation: class MyCog(commands.Cog): def __init__(self, bot): - self.bot= bot + self.bot = bot self._batch = [] self.lock = asyncio.Lock() self.bulker.start() diff --git a/docs/migrating.rst b/docs/migrating.rst index 5219a9d99..e6f3abbb4 100644 --- a/docs/migrating.rst +++ b/docs/migrating.rst @@ -1192,7 +1192,7 @@ The main differences between text channels and threads are: - :attr:`Permissions.create_private_threads` - :attr:`Permissions.send_messages_in_threads` -- Threads do not have their own nsfw status, they inherit it from their parent channel. +- Threads do not have their own NSFW status, they inherit it from their parent channel. - This means that :class:`Thread` does not have an ``nsfw`` attribute. @@ -1311,10 +1311,6 @@ The following have been changed: - Note that this method will return ``None`` instead of :class:`StageChannel` if the edit was only positional. -- :meth:`StoreChannel.edit` - - - Note that this method will return ``None`` instead of :class:`StoreChannel` if the edit was only positional. - - :meth:`TextChannel.edit` - Note that this method will return ``None`` instead of :class:`TextChannel` if the edit was only positional. @@ -1589,7 +1585,6 @@ The following methods have been changed: - :meth:`Role.edit` - :meth:`StageChannel.edit` - :meth:`StageInstance.edit` -- :meth:`StoreChannel.edit` - :meth:`StreamIntegration.edit` - :meth:`TextChannel.edit` - :meth:`VoiceChannel.edit` @@ -1608,6 +1603,33 @@ The following methods have been changed: - :meth:`Webhook.send` - :meth:`abc.GuildChannel.set_permissions` +Removal of ``StoreChannel`` +----------------------------- + +Discord's API has removed store channels as of `March 10th, 2022 `_. Therefore, the library has removed support for it as well. + +This removes the following: + +- ``StoreChannel`` +- ``commands.StoreChannelConverter`` +- ``ChannelType.store`` + +Change in ``Guild.bans`` endpoint +----------------------------------- + +Due to a breaking API change by Discord, :meth:`Guild.bans` no longer returns a list of every ban in the guild but instead is paginated using an asynchronous iterator. + +.. code-block:: python3 + + # before + + bans = await guild.bans() + + # after + async for ban in guild.bans(limit=1000): + ... + + Function Signature Changes ---------------------------- @@ -1632,7 +1654,7 @@ Parameters in the following methods are now all positional-only: - :meth:`Client.fetch_webhook` - :meth:`Client.fetch_widget` - :meth:`Message.add_reaction` -- :meth:`Client.error` +- :meth:`Client.on_error` - :meth:`abc.Messageable.fetch_message` - :meth:`abc.GuildChannel.permissions_for` - :meth:`DMChannel.get_partial_message` @@ -1858,6 +1880,9 @@ The following changes have been made: - :meth:`Permissions.stage_moderator` now includes the :attr:`Permissions.manage_channels` permission and the :attr:`Permissions.request_to_speak` permission is no longer included. +- :attr:`File.filename` will no longer be ``None``, in situations where previously this was the case the filename is set to `'untitled'`. + + .. _migrating_2_0_commands: Command Extension Changes @@ -1872,7 +1897,7 @@ As an extension to the :ref:`asyncio changes ` To accommodate this, the following changes have been made: -- the ``setup`` and ``teardown`` functions in extensions must now be coroutines. +- The ``setup`` and ``teardown`` functions in extensions must now be coroutines. - :meth:`ext.commands.Bot.load_extension` must now be awaited. - :meth:`ext.commands.Bot.unload_extension` must now be awaited. - :meth:`ext.commands.Bot.reload_extension` must now be awaited. @@ -2072,7 +2097,7 @@ The following attributes have been removed: - Use :attr:`ext.commands.Context.clean_prefix` instead. -Miscellanous Changes +Miscellaneous Changes ~~~~~~~~~~~~~~~~~~~~~~ - :meth:`ext.commands.Bot.add_cog` is now raising :exc:`ClientException` when a cog with the same name is already loaded. diff --git a/examples/background_task.py b/examples/background_task.py index a8a7f99c3..fd61efbcc 100644 --- a/examples/background_task.py +++ b/examples/background_task.py @@ -2,6 +2,7 @@ from discord.ext import tasks import discord + class MyClient(discord.Client): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -17,15 +18,16 @@ class MyClient(discord.Client): print(f'Logged in as {self.user} (ID: {self.user.id})') print('------') - @tasks.loop(seconds=60) # task runs every 60 seconds + @tasks.loop(seconds=60) # task runs every 60 seconds async def my_background_task(self): - channel = self.get_channel(1234567) # channel ID goes here + channel = self.get_channel(1234567) # channel ID goes here self.counter += 1 await channel.send(self.counter) @my_background_task.before_loop async def before_my_task(self): - await self.wait_until_ready() # wait until the bot logs in + await self.wait_until_ready() # wait until the bot logs in + client = MyClient() client.run('token') diff --git a/examples/background_task_asyncio.py b/examples/background_task_asyncio.py index 860916bb0..9f895fbca 100644 --- a/examples/background_task_asyncio.py +++ b/examples/background_task_asyncio.py @@ -1,6 +1,7 @@ import discord import asyncio + class MyClient(discord.Client): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -16,11 +17,11 @@ class MyClient(discord.Client): async def my_background_task(self): await self.wait_until_ready() counter = 0 - channel = self.get_channel(1234567) # channel ID goes here + channel = self.get_channel(1234567) # channel ID goes here while not self.is_closed(): counter += 1 await channel.send(counter) - await asyncio.sleep(60) # task runs every 60 seconds + await asyncio.sleep(60) # task runs every 60 seconds client = MyClient() diff --git a/examples/basic_bot.py b/examples/basic_bot.py index af5a803f2..b1943c53b 100644 --- a/examples/basic_bot.py +++ b/examples/basic_bot.py @@ -9,16 +9,19 @@ There are a number of utility commands being showcased here.''' bot = commands.Bot(command_prefix='?', description=description, self_bot=True) + @bot.event async def on_ready(): print(f'Logged in as {bot.user} (ID: {bot.user.id})') print('------') + @bot.command() async def add(ctx, left: int, right: int): """Adds two numbers together.""" await ctx.send(left + right) + @bot.command() async def roll(ctx, dice: str): """Rolls a dice in NdN format.""" @@ -31,22 +34,26 @@ async def roll(ctx, dice: str): result = ', '.join(str(random.randint(1, limit)) for r in range(rolls)) await ctx.send(result) + @bot.command(description='For when you wanna settle the score some other way') async def choose(ctx, *choices: str): """Chooses between multiple choices.""" await ctx.send(random.choice(choices)) + @bot.command() async def repeat(ctx, times: int, content='repeating...'): """Repeats a message multiple times.""" for i in range(times): await ctx.send(content) + @bot.command() async def joined(ctx, member: discord.Member): """Says when a member joined.""" await ctx.send(f'{member.name} joined in {member.joined_at}') + @bot.group() async def cool(ctx): """Says if a user is cool. @@ -56,9 +63,11 @@ async def cool(ctx): if ctx.invoked_subcommand is None: await ctx.send(f'No, {ctx.subcommand_passed} is not cool') + @cool.command(name='bot') async def _bot(ctx): """Is the bot cool?""" await ctx.send('Yes, the bot is cool.') + bot.run('token') diff --git a/examples/basic_voice.py b/examples/basic_voice.py index 788073a19..ca1e48357 100644 --- a/examples/basic_voice.py +++ b/examples/basic_voice.py @@ -20,11 +20,11 @@ ytdl_format_options = { 'quiet': True, 'no_warnings': True, 'default_search': 'auto', - 'source_address': '0.0.0.0' # bind to ipv4 since ipv6 addresses cause issues sometimes + 'source_address': '0.0.0.0', # bind to ipv4 since ipv6 addresses cause issues sometimes } ffmpeg_options = { - 'options': '-vn' + 'options': '-vn', } ytdl = youtube_dl.YoutubeDL(ytdl_format_options) @@ -131,9 +131,11 @@ async def on_ready(): print(f'Logged in as {bot.user} (ID: {bot.user.id})') print('------') + async def main(): async with bot: await bot.add_cog(Music(bot)) await bot.start('token') + asyncio.run(main()) diff --git a/examples/converters.py b/examples/converters.py index 7b0f17aed..bc82a6a4d 100644 --- a/examples/converters.py +++ b/examples/converters.py @@ -26,6 +26,7 @@ async def userinfo(ctx: commands.Context, user: discord.User): avatar = user.display_avatar.url await ctx.send(f'User found: {user_id} -- {username}\n{avatar}') + @userinfo.error async def userinfo_error(ctx: commands.Context, error: commands.CommandError): # if the conversion above fails for any reason, it will raise `commands.BadArgument` @@ -33,6 +34,7 @@ async def userinfo_error(ctx: commands.Context, error: commands.CommandError): if isinstance(error, commands.BadArgument): return await ctx.send('Couldn\'t find that user.') + # Custom Converter here class ChannelOrMemberConverter(commands.Converter): async def convert(self, ctx: commands.Context, argument: str): @@ -68,16 +70,16 @@ class ChannelOrMemberConverter(commands.Converter): raise commands.BadArgument(f'No Member or TextChannel could be converted from "{argument}"') - @bot.command() async def notify(ctx: commands.Context, target: ChannelOrMemberConverter): # This command signature utilises the custom converter written above # What will happen during command invocation is that the `target` above will be passed to - # the `argument` parameter of the `ChannelOrMemberConverter.convert` method and + # the `argument` parameter of the `ChannelOrMemberConverter.convert` method and # the conversion will go through the process defined there. await target.send(f'Hello, {target.name}!') + @bot.command() async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel]): # This command signature utilises the `typing.Union` typehint. @@ -91,9 +93,10 @@ async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, dis # To check the resulting type, `isinstance` is used if isinstance(target, discord.Member): await ctx.send(f'Member found: {target.mention}, adding them to the ignore list.') - elif isinstance(target, discord.TextChannel): # this could be an `else` but for completeness' sake. + elif isinstance(target, discord.TextChannel): # this could be an `else` but for completeness' sake. await ctx.send(f'Channel found: {target.mention}, adding it to the ignore list.') + # Built-in type converters. @bot.command() async def multiply(ctx: commands.Context, number: int, maybe: bool): @@ -105,4 +108,5 @@ async def multiply(ctx: commands.Context, number: int, maybe: bool): return await ctx.send(number * 2) await ctx.send(number * 5) + bot.run('token') diff --git a/examples/custom_context.py b/examples/custom_context.py index 45450bd33..4a83f054f 100644 --- a/examples/custom_context.py +++ b/examples/custom_context.py @@ -30,9 +30,10 @@ class MyBot(commands.Bot): bot = MyBot(command_prefix='!', self_bot=True) + @bot.command() async def guess(ctx, number: int): - """ Guess a random number from 1 to 6. """ + """Guess a random number from 1 to 6.""" # explained in a previous example, this gives you # a random number from 1-6 value = random.randint(1, 6) @@ -41,8 +42,9 @@ async def guess(ctx, number: int): # or a red cross mark if it wasn't await ctx.tick(number == value) + # IMPORTANT: You shouldn't hard code your token -# these are very important, and leaking them can +# these are very important, and leaking them can # let people do very malicious things with your # bot. Try to use a file or something to keep # them private, and don't commit it to GitHub diff --git a/examples/deleted.py b/examples/deleted.py index 6e0c203df..f2dd02bbe 100644 --- a/examples/deleted.py +++ b/examples/deleted.py @@ -1,5 +1,6 @@ import discord + class MyClient(discord.Client): async def on_ready(self): print(f'Logged in as {self.user} (ID: {self.user.id})') diff --git a/examples/edits.py b/examples/edits.py index 367860461..c16253bf7 100644 --- a/examples/edits.py +++ b/examples/edits.py @@ -1,6 +1,7 @@ import discord import asyncio + class MyClient(discord.Client): async def on_ready(self): print(f'Logged in as {self.user} (ID: {self.user.id})') diff --git a/examples/guessing_game.py b/examples/guessing_game.py index a6a1113cd..3aedccc2b 100644 --- a/examples/guessing_game.py +++ b/examples/guessing_game.py @@ -2,6 +2,7 @@ import discord import random import asyncio + class MyClient(discord.Client): async def on_ready(self): print(f'Logged in as {self.user} (ID: {self.user.id})') diff --git a/examples/modal.py b/examples/modal.py deleted file mode 100644 index 5dbc8095a..000000000 --- a/examples/modal.py +++ /dev/null @@ -1,69 +0,0 @@ -import discord -from discord import app_commands - -import traceback - -# Just default intents and a `discord.Client` instance -# We don't need a `commands.Bot` instance because we are not -# creating text-based commands. -intents = discord.Intents.default() -client = discord.Client(intents=intents) - -# We need an `discord.app_commands.CommandTree` instance -# to register application commands (slash commands in this case) -tree = app_commands.CommandTree(client) - -# The guild in which this slash command will be registered. -# As global commands can take up to an hour to propagate, it is ideal -# to test it in a guild. -TEST_GUILD = discord.Object(ID) - -@client.event -async def on_ready(): - print(f'Logged in as {client.user} (ID: {client.user.id})') - print('------') - - # Sync the application command with Discord. - await tree.sync(guild=TEST_GUILD) - -class Feedback(discord.ui.Modal, title='Feedback'): - # Our modal classes MUST subclass `discord.ui.Modal`, - # but the title can be whatever you want. - - # This will be a short input, where the user can enter their name - # It will also have a placeholder, as denoted by the `placeholder` kwarg. - # By default, it is required and is a short-style input which is exactly - # what we want. - name = discord.ui.TextInput( - label='Name', - placeholder='Your name here...', - ) - - # This is a longer, paragraph style input, where user can submit feedback - # Unlike the name, it is not required. If filled out, however, it will - # only accept a maximum of 300 characters, as denoted by the - # `max_length=300` kwarg. - feedback = discord.ui.TextInput( - label='What do you think of this new feature?', - style=discord.TextStyle.long, - placeholder='Type your feedback here...', - required=False, - max_length=300, - ) - - async def on_submit(self, interaction: discord.Interaction): - await interaction.response.send_message(f'Thanks for your feedback, {self.name.value}!', ephemeral=True) - - async def on_error(self, error: Exception, interaction: discord.Interaction) -> None: - await interaction.response.send_message('Oops! Something went wrong.', ephemeral=True) - - # Make sure we know what the error actually is - traceback.print_tb(error.__traceback__) - - -@tree.command(guild=TEST_GUILD, description="Submit feedback") -async def feedback(interaction: discord.Interaction): - # Send the modal with an instance of our `Feedback` class - await interaction.response.send_modal(Feedback()) - -client.run('token') diff --git a/examples/new_member.py b/examples/new_member.py index 954750cfa..62d8f5cd6 100644 --- a/examples/new_member.py +++ b/examples/new_member.py @@ -1,5 +1,6 @@ import discord + class MyClient(discord.Client): async def on_ready(self): print(f'Logged in as {self.user} (ID: {self.user.id})') diff --git a/examples/reaction_roles.py b/examples/reaction_roles.py index 3f2109e87..c22f8a33b 100644 --- a/examples/reaction_roles.py +++ b/examples/reaction_roles.py @@ -1,14 +1,15 @@ import discord + class MyClient(discord.Client): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.role_message_id = 0 # ID of the message that can be reacted to to add/remove a role. + self.role_message_id = 0 # ID of the message that can be reacted to to add/remove a role. self.emoji_to_role = { - discord.PartialEmoji(name='🔴'): 0, # ID of the role associated with unicode emoji '🔴'. - discord.PartialEmoji(name='🟡'): 0, # ID of the role associated with unicode emoji '🟡'. - discord.PartialEmoji(name='green', id=0): 0, # ID of the role associated with a partial emoji's ID. + discord.PartialEmoji(name='🔴'): 0, # ID of the role associated with unicode emoji '🔴'. + discord.PartialEmoji(name='🟡'): 0, # ID of the role associated with unicode emoji '🟡'. + discord.PartialEmoji(name='green', id=0): 0, # ID of the role associated with a partial emoji's ID. } async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent): diff --git a/examples/reply.py b/examples/reply.py index b35ac669a..4e9f5c911 100644 --- a/examples/reply.py +++ b/examples/reply.py @@ -1,5 +1,6 @@ import discord + class MyClient(discord.Client): async def on_ready(self): print(f'Logged in as {self.user} (ID: {self.user.id})') diff --git a/examples/secret.py b/examples/secret.py index faf48f535..3d963ec19 100644 --- a/examples/secret.py +++ b/examples/secret.py @@ -5,15 +5,16 @@ from discord.ext import commands bot = commands.Bot(command_prefix=commands.when_mentioned, description="Nothing to see here!", self_bot=True) -# the `hidden` keyword argument hides it from the help command. +# the `hidden` keyword argument hides it from the help command. @bot.group(hidden=True) async def secret(ctx: commands.Context): """What is this "secret" you speak of?""" if ctx.invoked_subcommand is None: await ctx.send('Shh!', delete_after=5) + def create_overwrites(ctx, *objects): - """This is just a helper function that creates the overwrites for the + """This is just a helper function that creates the overwrites for the voice/text channels. A `discord.PermissionOverwrite` allows you to determine the permissions @@ -26,10 +27,7 @@ def create_overwrites(ctx, *objects): # a dict comprehension is being utilised here to set the same permission overwrites # for each `discord.Role` or `discord.Member`. - overwrites = { - obj: discord.PermissionOverwrite(view_channel=True) - for obj in objects - } + overwrites = {obj: discord.PermissionOverwrite(view_channel=True) for obj in objects} # prevents the default role (@everyone) from viewing the channel # if it isn't already allowed to view the channel. @@ -40,15 +38,16 @@ def create_overwrites(ctx, *objects): return overwrites + # since these commands rely on guild related features, # it is best to lock it to be guild-only. @secret.command() @commands.guild_only() async def text(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]): - """This makes a text channel with a specified name + """This makes a text channel with a specified name that is only visible to roles or members that are specified. """ - + overwrites = create_overwrites(ctx, *objects) await ctx.guild.create_text_channel( @@ -58,6 +57,7 @@ async def text(ctx: commands.Context, name: str, *objects: typing.Union[discord. reason='Very secret business.', ) + @secret.command() @commands.guild_only() async def voice(ctx: commands.Context, name: str, *objects: typing.Union[discord.Role, discord.Member]): @@ -70,9 +70,10 @@ async def voice(ctx: commands.Context, name: str, *objects: typing.Union[discord await ctx.guild.create_voice_channel( name, overwrites=overwrites, - reason='Very secret business.' + reason='Very secret business.', ) + @secret.command() @commands.guild_only() async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role): @@ -89,7 +90,7 @@ async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: disc name=emoji.name, image=emoji_bytes, roles=roles, - reason='Very secret business.' + reason='Very secret business.', ) diff --git a/pyproject.toml b/pyproject.toml index dc48efce4..82bb19f77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ exclude = [ "docs", ] reportUnnecessaryTypeIgnoreComment = "warning" +reportUnusedImport = "error" pythonVersion = "3.8" typeCheckingMode = "basic" diff --git a/tests/test_files.py b/tests/test_files.py new file mode 100644 index 000000000..6096c3a38 --- /dev/null +++ b/tests/test_files.py @@ -0,0 +1,129 @@ +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from io import BytesIO + +import discord + + +FILE = BytesIO() + + +def test_file_with_no_name(): + f = discord.File('.gitignore') + assert f.filename == '.gitignore' + + +def test_io_with_no_name(): + f = discord.File(FILE) + assert f.filename == 'untitled' + + +def test_file_with_name(): + f = discord.File('.gitignore', 'test') + assert f.filename == 'test' + + +def test_io_with_name(): + f = discord.File(FILE, 'test') + assert f.filename == 'test' + + +def test_file_with_no_name_and_spoiler(): + f = discord.File('.gitignore', spoiler=True) + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_with_spoiler_name_and_implicit_spoiler(): + f = discord.File('.gitignore', 'SPOILER_.gitignore') + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_with_spoiler_name_and_spoiler(): + f = discord.File('.gitignore', 'SPOILER_.gitignore', spoiler=True) + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_with_spoiler_name_and_not_spoiler(): + f = discord.File('.gitignore', 'SPOILER_.gitignore', spoiler=False) + assert f.filename == '.gitignore' + assert f.spoiler == False + + +def test_file_with_name_and_double_spoiler_and_implicit_spoiler(): + f = discord.File('.gitignore', 'SPOILER_SPOILER_.gitignore') + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_with_name_and_double_spoiler_and_spoiler(): + f = discord.File('.gitignore', 'SPOILER_SPOILER_.gitignore', spoiler=True) + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_with_name_and_double_spoiler_and_not_spoiler(): + f = discord.File('.gitignore', 'SPOILER_SPOILER_.gitignore', spoiler=False) + assert f.filename == '.gitignore' + assert f.spoiler == False + + +def test_file_with_spoiler_with_overriding_name_not_spoiler(): + f = discord.File('.gitignore', spoiler=True) + f.filename = '.gitignore' + assert f.filename == '.gitignore' + assert f.spoiler == False + + +def test_file_with_spoiler_with_overriding_name_spoiler(): + f = discord.File('.gitignore', spoiler=True) + f.filename = 'SPOILER_.gitignore' + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_not_spoiler_with_overriding_name_not_spoiler(): + f = discord.File('.gitignore') + f.filename = '.gitignore' + assert f.filename == '.gitignore' + assert f.spoiler == False + + +def test_file_not_spoiler_with_overriding_name_spoiler(): + f = discord.File('.gitignore') + f.filename = 'SPOILER_.gitignore' + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True + + +def test_file_not_spoiler_with_overriding_name_double_spoiler(): + f = discord.File('.gitignore') + f.filename = 'SPOILER_SPOILER_.gitignore' + assert f.filename == 'SPOILER_.gitignore' + assert f.spoiler == True diff --git a/tests/test_utils.py b/tests/test_utils.py index b1977da98..5f8060d52 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -164,7 +164,9 @@ def test_resolve_template(url, code): assert utils.resolve_template(url) == code -@pytest.mark.parametrize('mention', ['@everyone', '@here', '<@80088516616269824>', '<@!80088516616269824>', '<@&381978264698224660>']) +@pytest.mark.parametrize( + 'mention', ['@everyone', '@here', '<@80088516616269824>', '<@!80088516616269824>', '<@&381978264698224660>'] +) def test_escape_mentions(mention): assert mention not in utils.escape_mentions(mention) assert mention not in utils.escape_mentions(f"one {mention} two") @@ -198,6 +200,37 @@ def test_resolve_annotation(annotation, resolved): assert resolved == utils.resolve_annotation(annotation, globals(), locals(), None) +@pytest.mark.parametrize( + ('annotation', 'resolved', 'check_cache'), + [ + (datetime.datetime, datetime.datetime, False), + ('datetime.datetime', datetime.datetime, True), + ( + 'typing.Union[typing.Literal["a"], typing.Literal["b"]]', + typing.Union[typing.Literal["a"], typing.Literal["b"]], + True, + ), + ('typing.Union[typing.Union[int, str], typing.Union[bool, dict]]', typing.Union[int, str, bool, dict], True), + ], +) +def test_resolve_annotation_with_cache(annotation, resolved, check_cache): + cache = {} + + assert resolved == utils.resolve_annotation(annotation, globals(), locals(), cache) + + if check_cache: + assert len(cache) == 1 + + cached_item = cache[annotation] + + latest = utils.resolve_annotation(annotation, globals(), locals(), cache) + + assert latest is cached_item + assert typing.get_origin(latest) is typing.get_origin(resolved) + else: + assert len(cache) == 0 + + def test_resolve_annotation_optional_normalisation(): value = utils.resolve_annotation('typing.Union[None, int]', globals(), locals(), None) assert value.__args__ == (int, type(None)) @@ -216,6 +249,30 @@ def test_resolve_annotation_310(annotation, resolved): assert resolved == utils.resolve_annotation(annotation, globals(), locals(), None) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="3.10 union syntax") +@pytest.mark.parametrize( + ('annotation', 'resolved'), + [ + ('int | None', typing.Optional[int]), + ('str | int', typing.Union[str, int]), + ('str | int | None', typing.Optional[typing.Union[str, int]]), + ], +) +def test_resolve_annotation_with_cache_310(annotation, resolved): + cache = {} + + assert resolved == utils.resolve_annotation(annotation, globals(), locals(), cache) + assert typing.get_origin(resolved) is typing.Union + + assert len(cache) == 1 + + cached_item = cache[annotation] + + latest = utils.resolve_annotation(annotation, globals(), locals(), cache) + assert latest is cached_item + assert typing.get_origin(latest) is typing.get_origin(resolved) + + # is_inside_class tests