From 346f447da6bb6bfe724098d3c0f39b8f17e11c95 Mon Sep 17 00:00:00 2001 From: dolfies Date: Sat, 2 Apr 2022 12:12:57 -0400 Subject: [PATCH] Fix typing issues, make ClientUser relationship properties return relationships --- discord/client.py | 10 +++++----- discord/team.py | 6 +++--- discord/types/user.py | 1 + discord/user.py | 35 +++++++++++++++++++++-------------- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/discord/client.py b/discord/client.py index cdef462c8..1ffd7571e 100644 --- a/discord/client.py +++ b/discord/client.py @@ -1370,10 +1370,10 @@ class Client: status_str = str(status) activities_tuple = tuple(a.to_dict() for a in activities) self._client_status._this = str(status) - self._client_activities['this'] = activities_tuple + self._client_activities['this'] = activities_tuple # type: ignore if self._session_count <= 1: self._client_status._status = status_str - self._client_activities[None] = self._client_activities['this'] = activities_tuple + self._client_activities[None] = self._client_activities['this'] = activities_tuple # type: ignore async def change_voice_state( self, @@ -2258,15 +2258,15 @@ class Client: return GroupChannel(me=self.user, data=data, state=state) # type: ignore - user is always present when logged in @overload - async def send_friend_request(self, user: BaseUser) -> Relationship: + async def send_friend_request(self, user: BaseUser, /) -> Relationship: ... @overload - async def send_friend_request(self, user: str) -> Relationship: + async def send_friend_request(self, user: str, /) -> Relationship: ... @overload - async def send_friend_request(self, username: str, discriminator: str) -> Relationship: + async def send_friend_request(self, username: str, discriminator: str, /) -> Relationship: ... async def send_friend_request(self, *args: Union[BaseUser, str]) -> Relationship: diff --git a/discord/team.py b/discord/team.py index 92b0cc271..cf71d3585 100644 --- a/discord/team.py +++ b/discord/team.py @@ -170,15 +170,15 @@ class Team: return members @overload - async def invite_member(self, user: BaseUser) -> TeamMember: + async def invite_member(self, user: BaseUser, /) -> TeamMember: ... @overload - async def invite_member(self, user: str) -> TeamMember: + async def invite_member(self, user: str, /) -> TeamMember: ... @overload - async def invite_member(self, username: str, discriminator: str) -> TeamMember: + async def invite_member(self, username: str, discriminator: str, /) -> TeamMember: ... async def invite_member(self, *args: Union[BaseUser, str]) -> TeamMember: diff --git a/discord/types/user.py b/discord/types/user.py index 50776d347..699322345 100644 --- a/discord/types/user.py +++ b/discord/types/user.py @@ -51,3 +51,4 @@ class User(PartialUser, total=False): bio: str analytics_token: str phone: Optional[str] + token: str \ No newline at end of file diff --git a/discord/user.py b/discord/user.py index 1081e9255..0d3f20304 100644 --- a/discord/user.py +++ b/discord/user.py @@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import discord.abc from .asset import Asset @@ -52,9 +52,10 @@ if TYPE_CHECKING: from datetime import datetime - from .abc import Snowflake as _Snowflake + from .abc import Snowflake as _Snowflake, T as ConnectReturn from .calls import PrivateCall from .channel import DMChannel + from .client import Client from .member import VoiceState from .message import Message from .profile import UserProfile @@ -544,7 +545,7 @@ class ClientUser(BaseUser): self.bio = data.get('bio') or None self.nsfw_allowed = data.get('nsfw_allowed', False) - def get_relationship(self, user_id: int) -> Relationship: + def get_relationship(self, user_id: int) -> Optional[Relationship]: """Retrieves the :class:`Relationship` if applicable. Parameters @@ -572,12 +573,12 @@ class ClientUser(BaseUser): @property def friends(self) -> List[Relationship]: r"""List[:class:`User`]: Returns all the users that the user is friends with.""" - return [r.user for r in self._state._relationships.values() if r.type is RelationshipType.friend] + return [r for r in self._state._relationships.values() if r.type is RelationshipType.friend] @property def blocked(self) -> List[Relationship]: r"""List[:class:`User`]: Returns all the users that the user has blocked.""" - return [r.user for r in self._state._relationships.values() if r.type is RelationshipType.blocked] + return [r for r in self._state._relationships.values() if r.type is RelationshipType.blocked] @property def settings(self) -> Optional[UserSettings]: @@ -842,10 +843,10 @@ class User(BaseUser, discord.abc.Connectable, discord.abc.Messageable): return f'<{self.__class__.__name__} id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot} system={self.system}>' def _get_voice_client_key(self) -> Tuple[int, str]: - return self._state.self_id, 'self_id' + return self._state.self_id, 'self_id' # type: ignore - self_id is always set at this point def _get_voice_state_pair(self) -> Tuple[int, int]: - return self._state.self_id, self.dm_channel.id + return self._state.self_id, self.dm_channel.id # type: ignore - self_id is always set at this point async def _get_channel(self) -> DMChannel: ch = await self.create_dm() @@ -867,16 +868,22 @@ class User(BaseUser, discord.abc.Connectable, discord.abc.Messageable): @property def relationship(self) -> Optional[Relationship]: """Optional[:class:`Relationship`]: Returns the :class:`Relationship` with this user if applicable, ``None`` otherwise.""" - return self._state.user.get_relationship(self.id) + return self._state.user.get_relationship(self.id) # type: ignore - user is always present when logged in - async def connect(self, *, ring=True, **kwargs): + @copy_doc(discord.abc.Connectable.connect) + async def connect( + self, + *, + timeout: float = 60.0, + reconnect: bool = True, + cls: Callable[[Client, discord.abc.Connectable], ConnectReturn] = MISSING, + ring: bool = True, + ) -> ConnectReturn: channel = await self._get_channel() - call = self.call - if call is not None: - ring = False - await super().connect(_channel=channel, **kwargs) - if ring: + call = channel.call + if call is None and ring: await channel._initial_ring() + return await super().connect(timeout=timeout, reconnect=reconnect, cls=cls, _channel=channel) async def create_dm(self) -> DMChannel: """|coro|