diff --git a/discord/user.py b/discord/user.py index fba241b48..df5ff864c 100644 --- a/discord/user.py +++ b/discord/user.py @@ -22,19 +22,36 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import Any, Dict, Optional, TYPE_CHECKING +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING + import discord.abc +from .asset import Asset +from .colour import Colour +from .enums import DefaultAvatar from .flags import PublicUserFlags from .utils import snowflake_time, _bytes_to_base64_data, MISSING -from .enums import DefaultAvatar -from .colour import Colour -from .asset import Asset + +if TYPE_CHECKING: + from datetime import datetime + + from .channel import DMChannel + from .guild import Guild + from .message import Message + from .state import ConnectionState + from .types.channel import DMChannel as DMChannelPayload + from .types.user import User as UserPayload + __all__ = ( 'User', 'ClientUser', ) +U = TypeVar('U', bound='User') +BU = TypeVar('BU', bound='BaseUser') + class _UserTag: __slots__ = () @@ -50,30 +67,35 @@ class BaseUser(_UserTag): discriminator: str bot: bool system: bool + _state: ConnectionState + _avatar: str + _banner: Optional[str] + _accent_colour: Optional[str] + _public_flags: int - def __init__(self, *, state, data): + def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: self._state = state self._update(data) - def __repr__(self): + def __repr__(self) -> str: return ( f"" ) - def __str__(self): + def __str__(self) -> str: return f'{self.name}#{self.discriminator}' - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, _UserTag) and other.id == self.id - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return self.id >> 22 - def _update(self, data): + def _update(self, data: UserPayload) -> None: self.name = data['username'] self.id = int(data['id']) self.discriminator = data['discriminator'] @@ -85,7 +107,7 @@ class BaseUser(_UserTag): self.system = data.get('system', False) @classmethod - def _copy(cls, user): + def _copy(cls: Type[BU], user: BU) -> BU: self = cls.__new__(cls) # bypass __init__ self.name = user.name @@ -100,7 +122,7 @@ class BaseUser(_UserTag): return self - def _to_minimal_user_json(self): + def _to_minimal_user_json(self) -> Dict[str, Any]: return { 'username': self.name, 'id': self.id, @@ -110,12 +132,12 @@ class BaseUser(_UserTag): } @property - def public_flags(self): + def public_flags(self) -> PublicUserFlags: """:class:`PublicUserFlags`: The publicly available flags the user has.""" return PublicUserFlags._from_value(self._public_flags) @property - def avatar(self): + def avatar(self) -> Asset: """:class:`Asset`: Returns an :class:`Asset` for the avatar the user has. If the user does not have a traditional avatar, an asset for @@ -127,7 +149,7 @@ class BaseUser(_UserTag): return Asset._from_avatar(self._state, self.id, self._avatar) @property - def default_avatar(self): + def default_avatar(self) -> Asset: """:class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator.""" return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar)) @@ -176,7 +198,7 @@ class BaseUser(_UserTag): return self.accent_colour @property - def colour(self): + def colour(self) -> Colour: """:class:`Colour`: A property that returns a colour denoting the rendered colour for the user. This always returns :meth:`Colour.default`. @@ -185,7 +207,7 @@ class BaseUser(_UserTag): return Colour.default() @property - def color(self): + def color(self) -> Colour: """:class:`Colour`: A property that returns a color denoting the rendered color for the user. This always returns :meth:`Colour.default`. @@ -194,12 +216,12 @@ class BaseUser(_UserTag): return self.colour @property - def mention(self): + def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention the given user.""" return f'<@{self.id}>' @property - def created_at(self): + def created_at(self) -> datetime: """:class:`datetime.datetime`: Returns the user's creation time in UTC. This is when the user's Discord account was created. @@ -207,7 +229,7 @@ class BaseUser(_UserTag): return snowflake_time(self.id) @property - def display_name(self): + def display_name(self) -> str: """:class:`str`: Returns the user's display name. For regular users this is just their username, but @@ -216,7 +238,7 @@ class BaseUser(_UserTag): """ return self.name - def mentioned_in(self, message): + def mentioned_in(self, message: Message) -> bool: """Checks if the user is mentioned in the specified message. Parameters @@ -282,16 +304,22 @@ class ClientUser(BaseUser): __slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__') - def __init__(self, *, state, data): + if TYPE_CHECKING: + verified: bool + local: Optional[str] + mfa_enabled: bool + _flags: int + + def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: super().__init__(state=state, data=data) - def __repr__(self): + def __repr__(self) -> str: return ( f'' ) - def _update(self, data): + def _update(self, data: UserPayload) -> None: super()._update(data) # There's actually an Optional[str] phone field as well but I won't use it self.verified = data.get('verified', False) @@ -335,7 +363,7 @@ class ClientUser(BaseUser): if avatar is not MISSING: payload['avatar'] = _bytes_to_base64_data(avatar) - data = await self._state.http.edit_profile(payload) + data: UserPayload = await self._state.http.edit_profile(payload) self._update(data) @@ -376,11 +404,14 @@ class User(BaseUser, discord.abc.Messageable): __slots__ = ('_stored',) - def __init__(self, *, state, data): + if TYPE_CHECKING: + _stored: bool + + def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: super().__init__(state=state, data=data) self._stored = False - def __repr__(self): + def __repr__(self) -> str: return f'' def __del__(self) -> None: @@ -391,17 +422,17 @@ class User(BaseUser, discord.abc.Messageable): pass @classmethod - def _copy(cls, user): + def _copy(cls: Type[U], user: U) -> U: self = super()._copy(user) self._stored = False return self - async def _get_channel(self): + async def _get_channel(self) -> DMChannel: ch = await self.create_dm() return ch @property - def dm_channel(self): + def dm_channel(self) -> Optional[DMChannel]: """Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists. If this returns ``None``, you can create a DM channel by calling the @@ -410,7 +441,7 @@ class User(BaseUser, discord.abc.Messageable): return self._state._get_private_channel_by_user(self.id) @property - def mutual_guilds(self): + def mutual_guilds(self) -> List[Guild]: """List[:class:`Guild`]: The guilds that the user shares with the client. .. note:: @@ -421,7 +452,7 @@ class User(BaseUser, discord.abc.Messageable): """ return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)] - async def create_dm(self): + async def create_dm(self) -> DMChannel: """|coro| Creates a :class:`DMChannel` with this user. @@ -439,5 +470,5 @@ class User(BaseUser, discord.abc.Messageable): return found state = self._state - data = await state.http.start_private_message(self.id) + data: DMChannelPayload = await state.http.start_private_message(self.id) return state.add_dm_channel(data)