Browse Source

Type and format abc.py

There's still some stuff missing but this is a decent first pass
pull/7138/head
Rapptz 4 years ago
parent
commit
55c7de82d3
  1. 247
      discord/abc.py

247
discord/abc.py

@ -26,7 +26,21 @@ from __future__ import annotations
import copy import copy
import asyncio import asyncio
from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, Type, TypeVar, Union, overload, runtime_checkable from typing import (
Any,
Dict,
List,
Mapping,
Optional,
TYPE_CHECKING,
Protocol,
Tuple,
Type,
TypeVar,
Union,
overload,
runtime_checkable,
)
from .iterators import HistoryIterator from .iterators import HistoryIterator
from .context_managers import Typing from .context_managers import Typing
@ -62,16 +76,24 @@ if TYPE_CHECKING:
from .channel import CategoryChannel from .channel import CategoryChannel
from .embeds import Embed from .embeds import Embed
from .message import Message, MessageReference from .message import Message, MessageReference
from .channel import TextChannel, DMChannel, GroupChannel
from .threads import Thread
from .enums import InviteTarget from .enums import InviteTarget
from .ui.view import View from .ui.view import View
from .types.channel import (
PermissionOverwrite as PermissionOverwritePayload,
GuildChannel as GuildChannelPayload,
OverwriteType,
)
MessageableChannel = Union[TextChannel, Thread, DMChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime] SnowflakeTime = Union["Snowflake", datetime]
MISSING = utils.MISSING MISSING = utils.MISSING
class _Undefined: class _Undefined:
def __repr__(self): def __repr__(self) -> str:
return 'see-below' return 'see-below'
@ -102,6 +124,7 @@ class Snowflake(Protocol):
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC.""" """:class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
raise NotImplementedError raise NotImplementedError
@runtime_checkable @runtime_checkable
class User(Snowflake, Protocol): class User(Snowflake, Protocol):
"""An ABC that details the common operations on a Discord user. """An ABC that details the common operations on a Discord user.
@ -172,13 +195,13 @@ class _Overwrites:
ROLE = 0 ROLE = 0
MEMBER = 1 MEMBER = 1
def __init__(self, **kwargs): def __init__(self, data: PermissionOverwritePayload):
self.id = kwargs.pop('id') self.id: int = int(data.pop('id'))
self.allow = int(kwargs.pop('allow', 0)) self.allow: int = int(data.pop('allow', 0))
self.deny = int(kwargs.pop('deny', 0)) self.deny: int = int(data.pop('deny', 0))
self.type = kwargs.pop('type') self.type: OverwriteType = data.pop('type')
def _asdict(self): def _asdict(self) -> PermissionOverwritePayload:
return { return {
'id': self.id, 'id': self.id,
'allow': str(self.allow), 'allow': str(self.allow),
@ -208,11 +231,6 @@ class GuildChannel:
This ABC must also implement :class:`~discord.abc.Snowflake`. This ABC must also implement :class:`~discord.abc.Snowflake`.
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -230,7 +248,10 @@ class GuildChannel:
name: str name: str
guild: Guild guild: Guild
type: ChannelType type: ChannelType
position: int
category_id: Optional[int]
_state: ConnectionState _state: ConnectionState
_overwrites: List[_Overwrites]
if TYPE_CHECKING: if TYPE_CHECKING:
@ -254,13 +275,13 @@ class GuildChannel:
lock_permissions: bool = False, lock_permissions: bool = False,
*, *,
reason: Optional[str], reason: Optional[str],
): ) -> None:
if position < 0: if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.') raise InvalidArgument('Channel position cannot be less than 0.')
http = self._state.http http = self._state.http
bucket = self._sorting_bucket bucket = self._sorting_bucket
channels = [c for c in self.guild.channels if c._sorting_bucket == bucket] channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket]
channels.sort(key=lambda c: c.position) channels.sort(key=lambda c: c.position)
@ -277,7 +298,7 @@ class GuildChannel:
payload = [] payload = []
for index, c in enumerate(channels): for index, c in enumerate(channels):
d = {'id': c.id, 'position': index} d: Dict[str, Any] = {'id': c.id, 'position': index}
if parent_id is not _undefined and c.id == self.id: if parent_id is not _undefined and c.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions) d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d) payload.append(d)
@ -287,7 +308,7 @@ class GuildChannel:
if parent_id is not _undefined: if parent_id is not _undefined:
self.category_id = int(parent_id) if parent_id else None self.category_id = int(parent_id) if parent_id else None
async def _edit(self, options, reason): async def _edit(self, options: Dict[str, Any], reason: Optional[str]):
try: try:
parent = options.pop('category') parent = options.pop('category')
except KeyError: except KeyError:
@ -322,13 +343,15 @@ class GuildChannel:
if parent_id is not _undefined: if parent_id is not _undefined:
if lock_permissions: if lock_permissions:
category = self.guild.get_channel(parent_id) category = self.guild.get_channel(parent_id)
options['permission_overwrites'] = [c._asdict() for c in category._overwrites] if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options['parent_id'] = parent_id options['parent_id'] = parent_id
elif lock_permissions and self.category_id is not None: elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it # if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category # we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id) category = self.guild.get_channel(self.category_id)
options['permission_overwrites'] = [c._asdict() for c in category._overwrites] if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
else: else:
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
@ -367,19 +390,19 @@ class GuildChannel:
data = await self._state.http.edit_channel(self.id, reason=reason, **options) data = await self._state.http.edit_channel(self.id, reason=reason, **options)
self._update(self.guild, data) self._update(self.guild, data)
def _fill_overwrites(self, data): def _fill_overwrites(self, data: GuildChannelPayload) -> None:
self._overwrites = [] self._overwrites = []
everyone_index = 0 everyone_index = 0
everyone_id = self.guild.id everyone_id = self.guild.id
for index, overridden in enumerate(data.get('permission_overwrites', [])): for index, overridden in enumerate(data.get('permission_overwrites', [])):
overridden_id = int(overridden.pop('id')) overwrite = _Overwrites(overridden)
self._overwrites.append(_Overwrites(id=overridden_id, **overridden)) self._overwrites.append(overwrite)
if overridden['type'] == _Overwrites.MEMBER: if overridden['type'] == _Overwrites.MEMBER:
continue continue
if overridden_id == everyone_id: if overwrite.id == everyone_id:
# the @everyone role is not guaranteed to be the first one # the @everyone role is not guaranteed to be the first one
# in the list of permission overwrites, however the permission # in the list of permission overwrites, however the permission
# resolution code kind of requires that it is the first one in # resolution code kind of requires that it is the first one in
@ -488,7 +511,7 @@ class GuildChannel:
If there is no category then this is ``None``. If there is no category then this is ``None``.
""" """
return self.guild.get_channel(self.category_id) return self.guild.get_channel(self.category_id) # type: ignore
@property @property
def permissions_synced(self) -> bool: def permissions_synced(self) -> bool:
@ -499,6 +522,9 @@ class GuildChannel:
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
if self.category_id is None:
return False
category = self.guild.get_channel(self.category_id) category = self.guild.get_channel(self.category_id)
return bool(category and category.overwrites == self.overwrites) return bool(category and category.overwrites == self.overwrites)
@ -679,14 +705,7 @@ class GuildChannel:
) -> None: ) -> None:
... ...
async def set_permissions( async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
self,
target,
*,
overwrite=_undefined,
reason=None,
**permissions
):
r"""|coro| r"""|coro|
Sets the channel specific permission overwrites for a target in the Sets the channel specific permission overwrites for a target in the
@ -801,7 +820,7 @@ class GuildChannel:
obj = cls(state=self._state, guild=self.guild, data=data) obj = cls(state=self._state, guild=self.guild, data=data)
# temporarily add it to the cache # temporarily add it to the cache
self.guild._channels[obj.id] = obj self.guild._channels[obj.id] = obj # type: ignore
return obj return obj
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH: async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
@ -956,6 +975,7 @@ class GuildChannel:
bucket = self._sorting_bucket bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING) parent_id = kwargs.get('category', MISSING)
# fmt: off # fmt: off
channels: List[GuildChannel]
if parent_id not in (MISSING, None): if parent_id not in (MISSING, None):
parent_id = parent_id.id parent_id = parent_id.id
channels = [ channels = [
@ -1017,7 +1037,7 @@ class GuildChannel:
unique: bool = True, unique: bool = True,
target_type: Optional[InviteTarget] = None, target_type: Optional[InviteTarget] = None,
target_user: Optional[User] = None, target_user: Optional[User] = None,
target_application_id: Optional[int] = None target_application_id: Optional[int] = None,
) -> Invite: ) -> Invite:
"""|coro| """|coro|
@ -1045,9 +1065,9 @@ class GuildChannel:
The reason for creating this invite. Shows up on the audit log. The reason for creating this invite. Shows up on the audit log.
target_type: Optional[:class:`.InviteTarget`] target_type: Optional[:class:`.InviteTarget`]
The type of target for the voice channel invite, if any. The type of target for the voice channel invite, if any.
.. versionadded:: 2.0 .. versionadded:: 2.0
target_user: Optional[:class:`User`] target_user: Optional[:class:`User`]
The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. The user must be streaming in the channel. The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. The user must be streaming in the channel.
@ -1081,7 +1101,7 @@ class GuildChannel:
unique=unique, unique=unique,
target_type=target_type.value if target_type else None, target_type=target_type.value if target_type else None,
target_user_id=target_user.id if target_user else None, target_user_id=target_user.id if target_user else None,
target_application_id=target_application_id target_application_id=target_application_id,
) )
return Invite.from_incomplete(data=data, state=self._state) return Invite.from_incomplete(data=data, state=self._state)
@ -1111,7 +1131,7 @@ class GuildChannel:
return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data]
class Messageable(Protocol): class Messageable:
"""An ABC that details the common operations on a model that can send messages. """An ABC that details the common operations on a model that can send messages.
The following implement this ABC: The following implement this ABC:
@ -1122,28 +1142,57 @@ class Messageable(Protocol):
- :class:`~discord.User` - :class:`~discord.User`
- :class:`~discord.Member` - :class:`~discord.Member`
- :class:`~discord.ext.commands.Context` - :class:`~discord.ext.commands.Context`
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
""" """
__slots__ = () __slots__ = ()
_state: ConnectionState
async def _get_channel(self): async def _get_channel(self) -> MessageableChannel:
raise NotImplementedError raise NotImplementedError
@overload @overload
async def send( async def send(
self, self,
content: Optional[str] =..., content: Optional[str] = ...,
*, *,
tts: bool = ..., tts: bool = ...,
embed: Embed = ..., embed: Embed = ...,
file: File = ..., file: File = ...,
delete_after: int = ..., delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ...,
mention_author: bool = ...,
view: View = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: List[File] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ...,
mention_author: bool = ...,
view: View = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: List[Embed] = ...,
file: File = ...,
delete_after: float = ...,
nonce: Union[str, int] = ..., nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ..., allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ..., reference: Union[Message, MessageReference] = ...,
@ -1160,7 +1209,7 @@ class Messageable(Protocol):
tts: bool = ..., tts: bool = ...,
embeds: List[Embed] = ..., embeds: List[Embed] = ...,
files: List[File] = ..., files: List[File] = ...,
delete_after: int = ..., delete_after: float = ...,
nonce: Union[str, int] = ..., nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ..., allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ..., reference: Union[Message, MessageReference] = ...,
@ -1169,10 +1218,22 @@ class Messageable(Protocol):
) -> Message: ) -> Message:
... ...
async def send(self, content=None, *, tts=False, embed=None, embeds=None, async def send(
file=None, files=None, delete_after=None, self,
nonce=None, allowed_mentions=None, reference=None, content=None,
mention_author=None, view=None): *,
tts=None,
embed=None,
embeds=None,
file=None,
files=None,
delete_after=None,
nonce=None,
allowed_mentions=None,
reference=None,
mention_author=None,
view=None,
):
"""|coro| """|coro|
Sends a message to the destination with the content given. Sends a message to the destination with the content given.
@ -1185,7 +1246,7 @@ class Messageable(Protocol):
single :class:`~discord.File` object. To upload multiple files, the ``files`` single :class:`~discord.File` object. To upload multiple files, the ``files``
parameter should be used with a :class:`list` of :class:`~discord.File` objects. parameter should be used with a :class:`list` of :class:`~discord.File` objects.
**Specifying both parameters will lead to an exception**. **Specifying both parameters will lead to an exception**.
To upload a single embed, the ``embed`` parameter should be used with a To upload a single embed, the ``embed`` parameter should be used with a
single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds`` single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds``
parameter should be used with a :class:`list` of :class:`~discord.Embed` objects. parameter should be used with a :class:`list` of :class:`~discord.Embed` objects.
@ -1193,7 +1254,7 @@ class Messageable(Protocol):
Parameters Parameters
------------ ------------
content: :class:`str` content: Optional[:class:`str`]
The content of the message to send. The content of the message to send.
tts: :class:`bool` tts: :class:`bool`
Indicates if the message should be sent using text-to-speech. Indicates if the message should be sent using text-to-speech.
@ -1261,13 +1322,13 @@ class Messageable(Protocol):
channel = await self._get_channel() channel = await self._get_channel()
state = self._state state = self._state
content = str(content) if content is not None else None content = str(content) if content is not None else None
if embed is not None and embeds is not None: if embed is not None and embeds is not None:
raise InvalidArgument('cannot pass both embed and embeds parameter to send()') raise InvalidArgument('cannot pass both embed and embeds parameter to send()')
if embed is not None: if embed is not None:
embed = embed.to_dict() embed = embed.to_dict()
elif embeds is not None: elif embeds is not None:
if len(embeds) > 10: if len(embeds) > 10:
raise InvalidArgument('embeds parameter must be a list of up to 10 elements') raise InvalidArgument('embeds parameter must be a list of up to 10 elements')
@ -1307,9 +1368,18 @@ class Messageable(Protocol):
raise InvalidArgument('file parameter must be File') raise InvalidArgument('file parameter must be File')
try: try:
data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions, data = await state.http.send_files(
content=content, tts=tts, embed=embed, embeds=embeds, channel.id,
nonce=nonce, message_reference=reference, components=components) files=[file],
allowed_mentions=allowed_mentions,
content=content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
message_reference=reference,
components=components,
)
finally: finally:
file.close() file.close()
@ -1320,17 +1390,33 @@ class Messageable(Protocol):
raise InvalidArgument('files parameter must be a list of File') raise InvalidArgument('files parameter must be a list of File')
try: try:
data = await state.http.send_files(channel.id, files=files, content=content, tts=tts, data = await state.http.send_files(
embed=embed, embeds=embeds, nonce=nonce, channel.id,
allowed_mentions=allowed_mentions, message_reference=reference, files=files,
components=components) content=content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
components=components,
)
finally: finally:
for f in files: for f in files:
f.close() f.close()
else: else:
data = await state.http.send_message(channel.id, content, tts=tts, embed=embed, data = await state.http.send_message(
embeds=embeds, nonce=nonce, allowed_mentions=allowed_mentions, channel.id,
message_reference=reference, components=components) content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
components=components,
)
ret = state.create_message(channel=channel, data=data) ret = state.create_message(channel=channel, data=data)
if view: if view:
@ -1340,7 +1426,7 @@ class Messageable(Protocol):
await ret.delete(delay=delete_after) await ret.delete(delay=delete_after)
return ret return ret
async def trigger_typing(self): async def trigger_typing(self) -> None:
"""|coro| """|coro|
Triggers a *typing* indicator to the destination. Triggers a *typing* indicator to the destination.
@ -1351,7 +1437,7 @@ class Messageable(Protocol):
channel = await self._get_channel() channel = await self._get_channel()
await self._state.http.send_typing(channel.id) await self._state.http.send_typing(channel.id)
def typing(self): def typing(self) -> Typing:
"""Returns a context manager that allows you to type for an indefinite period of time. """Returns a context manager that allows you to type for an indefinite period of time.
This is useful for denoting long computations in your bot. This is useful for denoting long computations in your bot.
@ -1362,8 +1448,8 @@ class Messageable(Protocol):
This means that both ``with`` and ``async with`` work with this. This means that both ``with`` and ``async with`` work with this.
Example Usage: :: Example Usage: ::
async with channel.typing(): async with channel.typing():
# simulate something heavy # simulate something heavy
await asyncio.sleep(10) await asyncio.sleep(10)
await channel.send('done!') await channel.send('done!')
@ -1371,7 +1457,7 @@ class Messageable(Protocol):
""" """
return Typing(self) return Typing(self)
async def fetch_message(self, id): async def fetch_message(self, id: int, /) -> Message:
"""|coro| """|coro|
Retrieves a single :class:`~discord.Message` from the destination. Retrieves a single :class:`~discord.Message` from the destination.
@ -1400,7 +1486,7 @@ class Messageable(Protocol):
data = await self._state.http.get_message(channel.id, id) data = await self._state.http.get_message(channel.id, id)
return self._state.create_message(channel=channel, data=data) return self._state.create_message(channel=channel, data=data)
async def pins(self): async def pins(self) -> List[Message]:
"""|coro| """|coro|
Retrieves all messages that are currently pinned in the channel. Retrieves all messages that are currently pinned in the channel.
@ -1427,7 +1513,15 @@ class Messageable(Protocol):
data = await state.http.pins_from(channel.id) data = await state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data] return [state.create_message(channel=channel, data=m) for m in data]
def history(self, *, limit=100, before=None, after=None, around=None, oldest_first=None): def history(
self,
*,
limit: Optional[int] = 100,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = None,
) -> HistoryIterator:
"""Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history. """Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history.
You must have :attr:`~discord.Permissions.read_message_history` permissions to use this. You must have :attr:`~discord.Permissions.read_message_history` permissions to use this.
@ -1504,11 +1598,12 @@ class Connectable(Protocol):
""" """
__slots__ = () __slots__ = ()
_state: ConnectionState
def _get_voice_client_key(self): def _get_voice_client_key(self) -> Tuple[int, str]:
raise NotImplementedError raise NotImplementedError
def _get_voice_state_pair(self): def _get_voice_state_pair(self) -> Tuple[int, int]:
raise NotImplementedError raise NotImplementedError
async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T: async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T:

Loading…
Cancel
Save