Browse Source

Type hint GuildChannel and don't make it a Protocol

This reverts GuildChannel back into a base class mixin.
pull/6874/head
Rapptz 4 years ago
parent
commit
c31946f29f
  1. 163
      discord/abc.py
  2. 22
      discord/http.py

163
discord/abc.py

@ -26,7 +26,7 @@ from __future__ import annotations
import copy
import asyncio
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, TypeVar, Union, overload, runtime_checkable
from .iterators import HistoryIterator
from .context_managers import Typing
@ -54,16 +54,21 @@ if TYPE_CHECKING:
from .user import ClientUser
from .asset import Asset
from .state import ConnectionState
from .guild import Guild
from .member import Member
from .channel import CategoryChannel
MISSING = utils.MISSING
class _Undefined:
def __repr__(self):
return 'see-below'
_undefined = _Undefined()
_undefined: Any = _Undefined()
@runtime_checkable
@ -81,6 +86,7 @@ class Snowflake(Protocol):
id: :class:`int`
The model's unique ID.
"""
__slots__ = ()
id: int
@ -113,6 +119,7 @@ class User(Snowflake, Protocol):
bot: :class:`bool`
If the user is a bot account.
"""
__slots__ = ()
name: str
@ -147,6 +154,7 @@ class PrivateChannel(Snowflake, Protocol):
me: :class:`~discord.ClientUser`
The user presenting yourself.
"""
__slots__ = ()
me: ClientUser
@ -179,7 +187,10 @@ class _Overwrites:
return self.type == 1
class GuildChannel(Snowflake, Protocol):
GCH = TypeVar('GCH', bound='GuildChannel')
class GuildChannel:
"""An ABC that details the common operations on a Discord guild channel.
The following implement this ABC:
@ -206,16 +217,38 @@ class GuildChannel(Snowflake, Protocol):
The position in the channel list. This is a number that starts at 0.
e.g. the top channel is position 0.
"""
__slots__ = ()
def __str__(self):
id: int
name: str
guild: Guild
type: ChannelType
_state: ConnectionState
if TYPE_CHECKING:
def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]):
...
def __str__(self) -> str:
return self.name
@property
def _sorting_bucket(self):
def _sorting_bucket(self) -> int:
raise NotImplementedError
def _update(self, guild: Guild, data: Dict[str, Any]) -> None:
raise NotImplementedError
async def _move(self, position, parent_id=None, lock_permissions=False, *, reason):
async def _move(
self,
position: int,
parent_id: Optional[Any] = None,
lock_permissions: bool = False,
*,
reason: Optional[str],
):
if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.')
@ -304,7 +337,7 @@ class GuildChannel(Snowflake, Protocol):
payload = {
'allow': allow.value,
'deny': deny.value,
'id': target.id
'id': target.id,
}
if isinstance(target, Role):
@ -354,7 +387,7 @@ class GuildChannel(Snowflake, Protocol):
tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index]
@property
def changed_roles(self):
def changed_roles(self) -> List[Role]:
"""List[:class:`~discord.Role`]: Returns a list of roles that have been overridden from
their default values in the :attr:`~discord.Guild.roles` attribute."""
ret = []
@ -370,16 +403,16 @@ class GuildChannel(Snowflake, Protocol):
return ret
@property
def mention(self):
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>'
@property
def created_at(self):
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return utils.snowflake_time(self.id)
def overwrites_for(self, obj):
def overwrites_for(self, obj: Union[Role, User]) -> PermissionOverwrite:
"""Returns the channel-specific overwrites for a member or a role.
Parameters
@ -410,7 +443,7 @@ class GuildChannel(Snowflake, Protocol):
return PermissionOverwrite()
@property
def overwrites(self):
def overwrites(self) -> Mapping[Union[Role, Member], PermissionOverwrite]:
"""Returns all of the channel's overwrites.
This is returned as a dictionary where the key contains the target which
@ -427,6 +460,7 @@ class GuildChannel(Snowflake, Protocol):
allow = Permissions(ow.allow)
deny = Permissions(ow.deny)
overwrite = PermissionOverwrite.from_pair(allow, deny)
target = None
if ow.is_role():
target = self.guild.get_role(ow.id)
@ -443,7 +477,7 @@ class GuildChannel(Snowflake, Protocol):
return ret
@property
def category(self):
def category(self) -> Optional[CategoryChannel]:
"""Optional[:class:`~discord.CategoryChannel`]: The category this channel belongs to.
If there is no category then this is ``None``.
@ -451,7 +485,7 @@ class GuildChannel(Snowflake, Protocol):
return self.guild.get_channel(self.category_id)
@property
def permissions_synced(self):
def permissions_synced(self) -> bool:
""":class:`bool`: Whether or not the permissions for this channel are synced with the
category it belongs to.
@ -462,7 +496,7 @@ class GuildChannel(Snowflake, Protocol):
category = self.guild.get_channel(self.category_id)
return bool(category and category.overwrites == self.overwrites)
def permissions_for(self, obj, /):
def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
"""Handles permission resolution for the :class:`~discord.Member`
or :class:`~discord.Role`.
@ -595,7 +629,7 @@ class GuildChannel(Snowflake, Protocol):
return base
async def delete(self, *, reason=None):
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the channel.
@ -619,7 +653,14 @@ class GuildChannel(Snowflake, Protocol):
"""
await self._state.http.delete_channel(self.id, reason=reason)
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Optional[PermissionOverwrite] = _undefined,
reason: Optional[str] = None,
**permissions: bool,
) -> None:
r"""|coro|
Sets the channel specific permission overwrites for a target in the
@ -714,10 +755,14 @@ class GuildChannel(Snowflake, Protocol):
else:
raise InvalidArgument('Invalid overwrite type provided.')
async def _clone_impl(self, base_attrs, *, name=None, reason=None):
base_attrs['permission_overwrites'] = [
x._asdict() for x in self._overwrites
]
async def _clone_impl(
self: GCH,
base_attrs: Dict[str, Any],
*,
name: Optional[str] = None,
reason: Optional[str] = None,
) -> GCH:
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id
base_attrs['name'] = name or self.name
guild_id = self.guild.id
@ -729,7 +774,7 @@ class GuildChannel(Snowflake, Protocol):
self.guild._channels[obj.id] = obj
return obj
async def clone(self, *, name=None, reason=None):
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
"""|coro|
Clones this channel. This creates a channel with the same properties
@ -762,7 +807,55 @@ class GuildChannel(Snowflake, Protocol):
"""
raise NotImplementedError
async def move(self, **kwargs):
@overload
async def move(
self,
*,
beginning: bool,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
@overload
async def move(
self,
*,
end: bool,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
@overload
async def move(
self,
*,
before: Snowflake,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
@overload
async def move(
self,
*,
after: Snowflake,
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
) -> None:
...
async def move(self, **kwargs) -> None:
"""|coro|
A rich interface to help move a channel relative to other channels.
@ -832,6 +925,7 @@ class GuildChannel(Snowflake, Protocol):
bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING)
# fmt: off
if parent_id not in (MISSING, None):
parent_id = parent_id.id
channels = [
@ -847,6 +941,7 @@ class GuildChannel(Snowflake, Protocol):
if ch._sorting_bucket == bucket
and ch.category_id == self.category_id
]
# fmt: on
channels.sort(key=lambda c: (c.position, c.id))
@ -882,7 +977,15 @@ class GuildChannel(Snowflake, Protocol):
await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason)
async def create_invite(self, *, reason=None, **fields):
async def create_invite(
self,
*,
reason: Optional[str] = None,
max_age: int = 0,
max_uses: int = 0,
temporary: bool = False,
unique: bool = True,
) -> Invite:
"""|coro|
Creates an instant invite from a text or voice channel.
@ -922,10 +1025,17 @@ class GuildChannel(Snowflake, Protocol):
The invite that was created.
"""
data = await self._state.http.create_invite(self.id, reason=reason, **fields)
data = await self._state.http.create_invite(
self.id,
reason=reason,
max_age=max_age,
max_uses=max_uses,
temporary=temporary,
unique=unique,
)
return Invite.from_incomplete(data=data, state=self._state)
async def invites(self):
async def invites(self) -> List[Invite]:
"""|coro|
Returns a list of all active instant invites from this channel.
@ -1283,6 +1393,7 @@ class Connectable(Protocol):
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
"""
__slots__ = ()
def _get_voice_client_key(self):

22
discord/http.py

@ -28,7 +28,7 @@ import asyncio
import json
import logging
import sys
from typing import Any, Coroutine, List, TYPE_CHECKING, TypeVar
from typing import Any, Coroutine, List, Optional, TYPE_CHECKING, TypeVar
from urllib.parse import quote as _uriquote
import weakref
@ -43,6 +43,7 @@ log = logging.getLogger(__name__)
if TYPE_CHECKING:
from .types import (
interactions,
invite,
)
T = TypeVar('T')
@ -966,13 +967,22 @@ class HTTPClient:
# Invite management
def create_invite(self, channel_id, *, reason=None, **options):
def create_invite(
self,
channel_id: int,
*,
reason: Optional[str] = None,
max_age: int = 0,
max_uses: int = 0,
temporary: bool = False,
unique: bool = True,
) -> Response[invite.Invite]:
r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id)
payload = {
'max_age': options.get('max_age', 0),
'max_uses': options.get('max_uses', 0),
'temporary': options.get('temporary', False),
'unique': options.get('unique', True),
'max_age': max_age,
'max_uses': max_uses,
'temporary': temporary,
'unique': unique,
}
return self.request(r, reason=reason, json=payload)

Loading…
Cancel
Save