From c31946f29f6695c703740bcc67cbf3f766bf7fc7 Mon Sep 17 00:00:00 2001 From: Rapptz Date: Wed, 5 May 2021 11:14:58 -0400 Subject: [PATCH] Type hint GuildChannel and don't make it a Protocol This reverts GuildChannel back into a base class mixin. --- discord/abc.py | 163 ++++++++++++++++++++++++++++++++++++++++-------- discord/http.py | 22 +++++-- 2 files changed, 153 insertions(+), 32 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index d7d965af2..3132ca736 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -26,7 +26,7 @@ from __future__ import annotations import copy import asyncio -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, TypeVar, Union, overload, runtime_checkable from .iterators import HistoryIterator from .context_managers import Typing @@ -54,16 +54,21 @@ if TYPE_CHECKING: from .user import ClientUser from .asset import Asset + from .state import ConnectionState + from .guild import Guild + from .member import Member + from .channel import CategoryChannel MISSING = utils.MISSING + class _Undefined: def __repr__(self): return 'see-below' -_undefined = _Undefined() +_undefined: Any = _Undefined() @runtime_checkable @@ -81,6 +86,7 @@ class Snowflake(Protocol): id: :class:`int` The model's unique ID. """ + __slots__ = () id: int @@ -113,6 +119,7 @@ class User(Snowflake, Protocol): bot: :class:`bool` If the user is a bot account. """ + __slots__ = () name: str @@ -147,6 +154,7 @@ class PrivateChannel(Snowflake, Protocol): me: :class:`~discord.ClientUser` The user presenting yourself. """ + __slots__ = () me: ClientUser @@ -179,7 +187,10 @@ class _Overwrites: return self.type == 1 -class GuildChannel(Snowflake, Protocol): +GCH = TypeVar('GCH', bound='GuildChannel') + + +class GuildChannel: """An ABC that details the common operations on a Discord guild channel. The following implement this ABC: @@ -206,16 +217,38 @@ class GuildChannel(Snowflake, Protocol): The position in the channel list. This is a number that starts at 0. e.g. the top channel is position 0. """ + __slots__ = () - def __str__(self): + id: int + name: str + guild: Guild + type: ChannelType + _state: ConnectionState + + if TYPE_CHECKING: + + def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]): + ... + + def __str__(self) -> str: return self.name @property - def _sorting_bucket(self): + def _sorting_bucket(self) -> int: + raise NotImplementedError + + def _update(self, guild: Guild, data: Dict[str, Any]) -> None: raise NotImplementedError - async def _move(self, position, parent_id=None, lock_permissions=False, *, reason): + async def _move( + self, + position: int, + parent_id: Optional[Any] = None, + lock_permissions: bool = False, + *, + reason: Optional[str], + ): if position < 0: raise InvalidArgument('Channel position cannot be less than 0.') @@ -304,7 +337,7 @@ class GuildChannel(Snowflake, Protocol): payload = { 'allow': allow.value, 'deny': deny.value, - 'id': target.id + 'id': target.id, } if isinstance(target, Role): @@ -354,7 +387,7 @@ class GuildChannel(Snowflake, Protocol): tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] @property - def changed_roles(self): + def changed_roles(self) -> List[Role]: """List[:class:`~discord.Role`]: Returns a list of roles that have been overridden from their default values in the :attr:`~discord.Guild.roles` attribute.""" ret = [] @@ -370,16 +403,16 @@ class GuildChannel(Snowflake, Protocol): return ret @property - def mention(self): + def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" return f'<#{self.id}>' @property - def created_at(self): + def created_at(self) -> datetime: """:class:`datetime.datetime`: Returns the channel's creation time in UTC.""" return utils.snowflake_time(self.id) - def overwrites_for(self, obj): + def overwrites_for(self, obj: Union[Role, User]) -> PermissionOverwrite: """Returns the channel-specific overwrites for a member or a role. Parameters @@ -410,7 +443,7 @@ class GuildChannel(Snowflake, Protocol): return PermissionOverwrite() @property - def overwrites(self): + def overwrites(self) -> Mapping[Union[Role, Member], PermissionOverwrite]: """Returns all of the channel's overwrites. This is returned as a dictionary where the key contains the target which @@ -427,6 +460,7 @@ class GuildChannel(Snowflake, Protocol): allow = Permissions(ow.allow) deny = Permissions(ow.deny) overwrite = PermissionOverwrite.from_pair(allow, deny) + target = None if ow.is_role(): target = self.guild.get_role(ow.id) @@ -443,7 +477,7 @@ class GuildChannel(Snowflake, Protocol): return ret @property - def category(self): + def category(self) -> Optional[CategoryChannel]: """Optional[:class:`~discord.CategoryChannel`]: The category this channel belongs to. If there is no category then this is ``None``. @@ -451,7 +485,7 @@ class GuildChannel(Snowflake, Protocol): return self.guild.get_channel(self.category_id) @property - def permissions_synced(self): + def permissions_synced(self) -> bool: """:class:`bool`: Whether or not the permissions for this channel are synced with the category it belongs to. @@ -462,7 +496,7 @@ class GuildChannel(Snowflake, Protocol): category = self.guild.get_channel(self.category_id) return bool(category and category.overwrites == self.overwrites) - def permissions_for(self, obj, /): + def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: """Handles permission resolution for the :class:`~discord.Member` or :class:`~discord.Role`. @@ -595,7 +629,7 @@ class GuildChannel(Snowflake, Protocol): return base - async def delete(self, *, reason=None): + async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| Deletes the channel. @@ -619,7 +653,14 @@ class GuildChannel(Snowflake, Protocol): """ await self._state.http.delete_channel(self.id, reason=reason) - async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): + async def set_permissions( + self, + target: Union[Member, Role], + *, + overwrite: Optional[PermissionOverwrite] = _undefined, + reason: Optional[str] = None, + **permissions: bool, + ) -> None: r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -714,10 +755,14 @@ class GuildChannel(Snowflake, Protocol): else: raise InvalidArgument('Invalid overwrite type provided.') - async def _clone_impl(self, base_attrs, *, name=None, reason=None): - base_attrs['permission_overwrites'] = [ - x._asdict() for x in self._overwrites - ] + async def _clone_impl( + self: GCH, + base_attrs: Dict[str, Any], + *, + name: Optional[str] = None, + reason: Optional[str] = None, + ) -> GCH: + base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites] base_attrs['parent_id'] = self.category_id base_attrs['name'] = name or self.name guild_id = self.guild.id @@ -729,7 +774,7 @@ class GuildChannel(Snowflake, Protocol): self.guild._channels[obj.id] = obj return obj - async def clone(self, *, name=None, reason=None): + async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH: """|coro| Clones this channel. This creates a channel with the same properties @@ -762,7 +807,55 @@ class GuildChannel(Snowflake, Protocol): """ raise NotImplementedError - async def move(self, **kwargs): + @overload + async def move( + self, + *, + beginning: bool, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + @overload + async def move( + self, + *, + end: bool, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + @overload + async def move( + self, + *, + before: Snowflake, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + @overload + async def move( + self, + *, + after: Snowflake, + offset: int = MISSING, + category: Optional[Snowflake] = MISSING, + sync_permissions: bool = MISSING, + reason: str = MISSING, + ) -> None: + ... + + async def move(self, **kwargs) -> None: """|coro| A rich interface to help move a channel relative to other channels. @@ -832,6 +925,7 @@ class GuildChannel(Snowflake, Protocol): bucket = self._sorting_bucket parent_id = kwargs.get('category', MISSING) + # fmt: off if parent_id not in (MISSING, None): parent_id = parent_id.id channels = [ @@ -847,6 +941,7 @@ class GuildChannel(Snowflake, Protocol): if ch._sorting_bucket == bucket and ch.category_id == self.category_id ] + # fmt: on channels.sort(key=lambda c: (c.position, c.id)) @@ -882,7 +977,15 @@ class GuildChannel(Snowflake, Protocol): await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) - async def create_invite(self, *, reason=None, **fields): + async def create_invite( + self, + *, + reason: Optional[str] = None, + max_age: int = 0, + max_uses: int = 0, + temporary: bool = False, + unique: bool = True, + ) -> Invite: """|coro| Creates an instant invite from a text or voice channel. @@ -922,10 +1025,17 @@ class GuildChannel(Snowflake, Protocol): The invite that was created. """ - data = await self._state.http.create_invite(self.id, reason=reason, **fields) + data = await self._state.http.create_invite( + self.id, + reason=reason, + max_age=max_age, + max_uses=max_uses, + temporary=temporary, + unique=unique, + ) return Invite.from_incomplete(data=data, state=self._state) - async def invites(self): + async def invites(self) -> List[Invite]: """|coro| Returns a list of all active instant invites from this channel. @@ -1283,6 +1393,7 @@ class Connectable(Protocol): This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass` checks. """ + __slots__ = () def _get_voice_client_key(self): diff --git a/discord/http.py b/discord/http.py index dcd6131e2..471d829e6 100644 --- a/discord/http.py +++ b/discord/http.py @@ -28,7 +28,7 @@ import asyncio import json import logging import sys -from typing import Any, Coroutine, List, TYPE_CHECKING, TypeVar +from typing import Any, Coroutine, List, Optional, TYPE_CHECKING, TypeVar from urllib.parse import quote as _uriquote import weakref @@ -43,6 +43,7 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: from .types import ( interactions, + invite, ) T = TypeVar('T') @@ -966,13 +967,22 @@ class HTTPClient: # Invite management - def create_invite(self, channel_id, *, reason=None, **options): + def create_invite( + self, + channel_id: int, + *, + reason: Optional[str] = None, + max_age: int = 0, + max_uses: int = 0, + temporary: bool = False, + unique: bool = True, + ) -> Response[invite.Invite]: r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id) payload = { - 'max_age': options.get('max_age', 0), - 'max_uses': options.get('max_uses', 0), - 'temporary': options.get('temporary', False), - 'unique': options.get('unique', True), + 'max_age': max_age, + 'max_uses': max_uses, + 'temporary': temporary, + 'unique': unique, } return self.request(r, reason=reason, json=payload)