diff --git a/discord/abc.py b/discord/abc.py index b94f9a717..df43dae52 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,7 +26,21 @@ from __future__ import annotations import copy import asyncio -from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, Type, TypeVar, Union, overload, runtime_checkable +from typing import ( + Any, + Dict, + List, + Mapping, + Optional, + TYPE_CHECKING, + Protocol, + Tuple, + Type, + TypeVar, + Union, + overload, + runtime_checkable, +) from .iterators import HistoryIterator from .context_managers import Typing @@ -62,16 +76,24 @@ if TYPE_CHECKING: from .channel import CategoryChannel from .embeds import Embed from .message import Message, MessageReference + from .channel import TextChannel, DMChannel, GroupChannel + from .threads import Thread from .enums import InviteTarget from .ui.view import View + from .types.channel import ( + PermissionOverwrite as PermissionOverwritePayload, + GuildChannel as GuildChannelPayload, + OverwriteType, + ) + MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] MISSING = utils.MISSING class _Undefined: - def __repr__(self): + def __repr__(self) -> str: return 'see-below' @@ -102,6 +124,7 @@ class Snowflake(Protocol): """:class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC.""" raise NotImplementedError + @runtime_checkable class User(Snowflake, Protocol): """An ABC that details the common operations on a Discord user. @@ -172,13 +195,13 @@ class _Overwrites: ROLE = 0 MEMBER = 1 - def __init__(self, **kwargs): - self.id = kwargs.pop('id') - self.allow = int(kwargs.pop('allow', 0)) - self.deny = int(kwargs.pop('deny', 0)) - self.type = kwargs.pop('type') + def __init__(self, data: PermissionOverwritePayload): + self.id: int = int(data.pop('id')) + self.allow: int = int(data.pop('allow', 0)) + self.deny: int = int(data.pop('deny', 0)) + self.type: OverwriteType = data.pop('type') - def _asdict(self): + def _asdict(self) -> PermissionOverwritePayload: return { 'id': self.id, 'allow': str(self.allow), @@ -208,11 +231,6 @@ class GuildChannel: This ABC must also implement :class:`~discord.abc.Snowflake`. - Note - ---- - This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` - checks. - Attributes ----------- name: :class:`str` @@ -230,7 +248,10 @@ class GuildChannel: name: str guild: Guild type: ChannelType + position: int + category_id: Optional[int] _state: ConnectionState + _overwrites: List[_Overwrites] if TYPE_CHECKING: @@ -254,13 +275,13 @@ class GuildChannel: lock_permissions: bool = False, *, reason: Optional[str], - ): + ) -> None: if position < 0: raise InvalidArgument('Channel position cannot be less than 0.') http = self._state.http bucket = self._sorting_bucket - channels = [c for c in self.guild.channels if c._sorting_bucket == bucket] + channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket] channels.sort(key=lambda c: c.position) @@ -277,7 +298,7 @@ class GuildChannel: payload = [] for index, c in enumerate(channels): - d = {'id': c.id, 'position': index} + d: Dict[str, Any] = {'id': c.id, 'position': index} if parent_id is not _undefined and c.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) @@ -287,7 +308,7 @@ class GuildChannel: if parent_id is not _undefined: self.category_id = int(parent_id) if parent_id else None - async def _edit(self, options, reason): + async def _edit(self, options: Dict[str, Any], reason: Optional[str]): try: parent = options.pop('category') except KeyError: @@ -322,13 +343,15 @@ class GuildChannel: if parent_id is not _undefined: if lock_permissions: category = self.guild.get_channel(parent_id) - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + if category: + options['permission_overwrites'] = [c._asdict() for c in category._overwrites] options['parent_id'] = parent_id elif lock_permissions and self.category_id is not None: # if we're syncing permissions on a pre-existing channel category without changing it # we need to update the permissions to point to the pre-existing category category = self.guild.get_channel(self.category_id) - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + if category: + options['permission_overwrites'] = [c._asdict() for c in category._overwrites] else: await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) @@ -367,19 +390,19 @@ class GuildChannel: data = await self._state.http.edit_channel(self.id, reason=reason, **options) self._update(self.guild, data) - def _fill_overwrites(self, data): + def _fill_overwrites(self, data: GuildChannelPayload) -> None: self._overwrites = [] everyone_index = 0 everyone_id = self.guild.id for index, overridden in enumerate(data.get('permission_overwrites', [])): - overridden_id = int(overridden.pop('id')) - self._overwrites.append(_Overwrites(id=overridden_id, **overridden)) + overwrite = _Overwrites(overridden) + self._overwrites.append(overwrite) if overridden['type'] == _Overwrites.MEMBER: continue - if overridden_id == everyone_id: + if overwrite.id == everyone_id: # the @everyone role is not guaranteed to be the first one # in the list of permission overwrites, however the permission # resolution code kind of requires that it is the first one in @@ -488,7 +511,7 @@ class GuildChannel: If there is no category then this is ``None``. """ - return self.guild.get_channel(self.category_id) + return self.guild.get_channel(self.category_id) # type: ignore @property def permissions_synced(self) -> bool: @@ -499,6 +522,9 @@ class GuildChannel: .. versionadded:: 1.3 """ + if self.category_id is None: + return False + category = self.guild.get_channel(self.category_id) return bool(category and category.overwrites == self.overwrites) @@ -679,14 +705,7 @@ class GuildChannel: ) -> None: ... - async def set_permissions( - self, - target, - *, - overwrite=_undefined, - reason=None, - **permissions - ): + async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -801,7 +820,7 @@ class GuildChannel: obj = cls(state=self._state, guild=self.guild, data=data) # temporarily add it to the cache - self.guild._channels[obj.id] = obj + self.guild._channels[obj.id] = obj # type: ignore return obj async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH: @@ -956,6 +975,7 @@ class GuildChannel: bucket = self._sorting_bucket parent_id = kwargs.get('category', MISSING) # fmt: off + channels: List[GuildChannel] if parent_id not in (MISSING, None): parent_id = parent_id.id channels = [ @@ -1017,7 +1037,7 @@ class GuildChannel: unique: bool = True, target_type: Optional[InviteTarget] = None, target_user: Optional[User] = None, - target_application_id: Optional[int] = None + target_application_id: Optional[int] = None, ) -> Invite: """|coro| @@ -1045,9 +1065,9 @@ class GuildChannel: The reason for creating this invite. Shows up on the audit log. target_type: Optional[:class:`.InviteTarget`] The type of target for the voice channel invite, if any. - + .. versionadded:: 2.0 - + target_user: Optional[:class:`User`] The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. The user must be streaming in the channel. @@ -1081,7 +1101,7 @@ class GuildChannel: unique=unique, target_type=target_type.value if target_type else None, target_user_id=target_user.id if target_user else None, - target_application_id=target_application_id + target_application_id=target_application_id, ) return Invite.from_incomplete(data=data, state=self._state) @@ -1111,7 +1131,7 @@ class GuildChannel: return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] -class Messageable(Protocol): +class Messageable: """An ABC that details the common operations on a model that can send messages. The following implement this ABC: @@ -1122,28 +1142,57 @@ class Messageable(Protocol): - :class:`~discord.User` - :class:`~discord.Member` - :class:`~discord.ext.commands.Context` - - - Note - ---- - This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` - checks. """ __slots__ = () + _state: ConnectionState - async def _get_channel(self): + async def _get_channel(self) -> MessageableChannel: raise NotImplementedError @overload async def send( self, - content: Optional[str] =..., + content: Optional[str] = ..., *, tts: bool = ..., embed: Embed = ..., file: File = ..., - delete_after: int = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference] = ..., + mention_author: bool = ..., + view: View = ..., + ) -> Message: + ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embed: Embed = ..., + files: List[File] = ..., + delete_after: float = ..., + nonce: Union[str, int] = ..., + allowed_mentions: AllowedMentions = ..., + reference: Union[Message, MessageReference] = ..., + mention_author: bool = ..., + view: View = ..., + ) -> Message: + ... + + @overload + async def send( + self, + content: Optional[str] = ..., + *, + tts: bool = ..., + embeds: List[Embed] = ..., + file: File = ..., + delete_after: float = ..., nonce: Union[str, int] = ..., allowed_mentions: AllowedMentions = ..., reference: Union[Message, MessageReference] = ..., @@ -1160,7 +1209,7 @@ class Messageable(Protocol): tts: bool = ..., embeds: List[Embed] = ..., files: List[File] = ..., - delete_after: int = ..., + delete_after: float = ..., nonce: Union[str, int] = ..., allowed_mentions: AllowedMentions = ..., reference: Union[Message, MessageReference] = ..., @@ -1169,10 +1218,22 @@ class Messageable(Protocol): ) -> Message: ... - async def send(self, content=None, *, tts=False, embed=None, embeds=None, - file=None, files=None, delete_after=None, - nonce=None, allowed_mentions=None, reference=None, - mention_author=None, view=None): + async def send( + self, + content=None, + *, + tts=None, + embed=None, + embeds=None, + file=None, + files=None, + delete_after=None, + nonce=None, + allowed_mentions=None, + reference=None, + mention_author=None, + view=None, + ): """|coro| Sends a message to the destination with the content given. @@ -1185,7 +1246,7 @@ class Messageable(Protocol): single :class:`~discord.File` object. To upload multiple files, the ``files`` parameter should be used with a :class:`list` of :class:`~discord.File` objects. **Specifying both parameters will lead to an exception**. - + To upload a single embed, the ``embed`` parameter should be used with a single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds`` parameter should be used with a :class:`list` of :class:`~discord.Embed` objects. @@ -1193,7 +1254,7 @@ class Messageable(Protocol): Parameters ------------ - content: :class:`str` + content: Optional[:class:`str`] The content of the message to send. tts: :class:`bool` Indicates if the message should be sent using text-to-speech. @@ -1261,13 +1322,13 @@ class Messageable(Protocol): channel = await self._get_channel() state = self._state content = str(content) if content is not None else None - + if embed is not None and embeds is not None: raise InvalidArgument('cannot pass both embed and embeds parameter to send()') - + if embed is not None: embed = embed.to_dict() - + elif embeds is not None: if len(embeds) > 10: raise InvalidArgument('embeds parameter must be a list of up to 10 elements') @@ -1307,9 +1368,18 @@ class Messageable(Protocol): raise InvalidArgument('file parameter must be File') try: - data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions, - content=content, tts=tts, embed=embed, embeds=embeds, - nonce=nonce, message_reference=reference, components=components) + data = await state.http.send_files( + channel.id, + files=[file], + allowed_mentions=allowed_mentions, + content=content, + tts=tts, + embed=embed, + embeds=embeds, + nonce=nonce, + message_reference=reference, + components=components, + ) finally: file.close() @@ -1320,17 +1390,33 @@ class Messageable(Protocol): raise InvalidArgument('files parameter must be a list of File') try: - data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, - embed=embed, embeds=embeds, nonce=nonce, - allowed_mentions=allowed_mentions, message_reference=reference, - components=components) + data = await state.http.send_files( + channel.id, + files=files, + content=content, + tts=tts, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + message_reference=reference, + components=components, + ) finally: for f in files: f.close() else: - data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, - embeds=embeds, nonce=nonce, allowed_mentions=allowed_mentions, - message_reference=reference, components=components) + data = await state.http.send_message( + channel.id, + content, + tts=tts, + embed=embed, + embeds=embeds, + nonce=nonce, + allowed_mentions=allowed_mentions, + message_reference=reference, + components=components, + ) ret = state.create_message(channel=channel, data=data) if view: @@ -1340,7 +1426,7 @@ class Messageable(Protocol): await ret.delete(delay=delete_after) return ret - async def trigger_typing(self): + async def trigger_typing(self) -> None: """|coro| Triggers a *typing* indicator to the destination. @@ -1351,7 +1437,7 @@ class Messageable(Protocol): channel = await self._get_channel() await self._state.http.send_typing(channel.id) - def typing(self): + def typing(self) -> Typing: """Returns a context manager that allows you to type for an indefinite period of time. This is useful for denoting long computations in your bot. @@ -1362,8 +1448,8 @@ class Messageable(Protocol): This means that both ``with`` and ``async with`` work with this. Example Usage: :: - async with channel.typing(): - # simulate something heavy + async with channel.typing(): + # simulate something heavy await asyncio.sleep(10) await channel.send('done!') @@ -1371,7 +1457,7 @@ class Messageable(Protocol): """ return Typing(self) - async def fetch_message(self, id): + async def fetch_message(self, id: int, /) -> Message: """|coro| Retrieves a single :class:`~discord.Message` from the destination. @@ -1400,7 +1486,7 @@ class Messageable(Protocol): data = await self._state.http.get_message(channel.id, id) return self._state.create_message(channel=channel, data=data) - async def pins(self): + async def pins(self) -> List[Message]: """|coro| Retrieves all messages that are currently pinned in the channel. @@ -1427,7 +1513,15 @@ class Messageable(Protocol): data = await state.http.pins_from(channel.id) return [state.create_message(channel=channel, data=m) for m in data] - def history(self, *, limit=100, before=None, after=None, around=None, oldest_first=None): + def history( + self, + *, + limit: Optional[int] = 100, + before: Optional[SnowflakeTime] = None, + after: Optional[SnowflakeTime] = None, + around: Optional[SnowflakeTime] = None, + oldest_first: Optional[bool] = None, + ) -> HistoryIterator: """Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history. You must have :attr:`~discord.Permissions.read_message_history` permissions to use this. @@ -1504,11 +1598,12 @@ class Connectable(Protocol): """ __slots__ = () + _state: ConnectionState - def _get_voice_client_key(self): + def _get_voice_client_key(self) -> Tuple[int, str]: raise NotImplementedError - def _get_voice_state_pair(self): + def _get_voice_state_pair(self) -> Tuple[int, int]: raise NotImplementedError async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T: