Browse Source

Type-hint user.py

pull/7373/head
thetimtoy 4 years ago
committed by GitHub
parent
commit
529fad6fec
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 99
      discord/user.py

99
discord/user.py

@ -22,19 +22,36 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Dict, Optional, TYPE_CHECKING
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING
import discord.abc
from .asset import Asset
from .colour import Colour
from .enums import DefaultAvatar
from .flags import PublicUserFlags
from .utils import snowflake_time, _bytes_to_base64_data, MISSING
from .enums import DefaultAvatar
from .colour import Colour
from .asset import Asset
if TYPE_CHECKING:
from datetime import datetime
from .channel import DMChannel
from .guild import Guild
from .message import Message
from .state import ConnectionState
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload
__all__ = (
'User',
'ClientUser',
)
U = TypeVar('U', bound='User')
BU = TypeVar('BU', bound='BaseUser')
class _UserTag:
__slots__ = ()
@ -50,30 +67,35 @@ class BaseUser(_UserTag):
discriminator: str
bot: bool
system: bool
_state: ConnectionState
_avatar: str
_banner: Optional[str]
_accent_colour: Optional[str]
_public_flags: int
def __init__(self, *, state, data):
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
self._state = state
self._update(data)
def __repr__(self):
def __repr__(self) -> str:
return (
f"<BaseUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} system={self.system}>"
)
def __str__(self):
def __str__(self) -> str:
return f'{self.name}#{self.discriminator}'
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def __hash__(self):
def __hash__(self) -> int:
return self.id >> 22
def _update(self, data):
def _update(self, data: UserPayload) -> None:
self.name = data['username']
self.id = int(data['id'])
self.discriminator = data['discriminator']
@ -85,7 +107,7 @@ class BaseUser(_UserTag):
self.system = data.get('system', False)
@classmethod
def _copy(cls, user):
def _copy(cls: Type[BU], user: BU) -> BU:
self = cls.__new__(cls) # bypass __init__
self.name = user.name
@ -100,7 +122,7 @@ class BaseUser(_UserTag):
return self
def _to_minimal_user_json(self):
def _to_minimal_user_json(self) -> Dict[str, Any]:
return {
'username': self.name,
'id': self.id,
@ -110,12 +132,12 @@ class BaseUser(_UserTag):
}
@property
def public_flags(self):
def public_flags(self) -> PublicUserFlags:
""":class:`PublicUserFlags`: The publicly available flags the user has."""
return PublicUserFlags._from_value(self._public_flags)
@property
def avatar(self):
def avatar(self) -> Asset:
""":class:`Asset`: Returns an :class:`Asset` for the avatar the user has.
If the user does not have a traditional avatar, an asset for
@ -127,7 +149,7 @@ class BaseUser(_UserTag):
return Asset._from_avatar(self._state, self.id, self._avatar)
@property
def default_avatar(self):
def default_avatar(self) -> Asset:
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
@ -176,7 +198,7 @@ class BaseUser(_UserTag):
return self.accent_colour
@property
def colour(self):
def colour(self) -> Colour:
""":class:`Colour`: A property that returns a colour denoting the rendered colour
for the user. This always returns :meth:`Colour.default`.
@ -185,7 +207,7 @@ class BaseUser(_UserTag):
return Colour.default()
@property
def color(self):
def color(self) -> Colour:
""":class:`Colour`: A property that returns a color denoting the rendered color
for the user. This always returns :meth:`Colour.default`.
@ -194,12 +216,12 @@ class BaseUser(_UserTag):
return self.colour
@property
def mention(self):
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user."""
return f'<@{self.id}>'
@property
def created_at(self):
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the user's creation time in UTC.
This is when the user's Discord account was created.
@ -207,7 +229,7 @@ class BaseUser(_UserTag):
return snowflake_time(self.id)
@property
def display_name(self):
def display_name(self) -> str:
""":class:`str`: Returns the user's display name.
For regular users this is just their username, but
@ -216,7 +238,7 @@ class BaseUser(_UserTag):
"""
return self.name
def mentioned_in(self, message):
def mentioned_in(self, message: Message) -> bool:
"""Checks if the user is mentioned in the specified message.
Parameters
@ -282,16 +304,22 @@ class ClientUser(BaseUser):
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
def __init__(self, *, state, data):
if TYPE_CHECKING:
verified: bool
local: Optional[str]
mfa_enabled: bool
_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
def __repr__(self):
def __repr__(self) -> str:
return (
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
)
def _update(self, data):
def _update(self, data: UserPayload) -> None:
super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get('verified', False)
@ -335,7 +363,7 @@ class ClientUser(BaseUser):
if avatar is not MISSING:
payload['avatar'] = _bytes_to_base64_data(avatar)
data = await self._state.http.edit_profile(payload)
data: UserPayload = await self._state.http.edit_profile(payload)
self._update(data)
@ -376,11 +404,14 @@ class User(BaseUser, discord.abc.Messageable):
__slots__ = ('_stored',)
def __init__(self, *, state, data):
if TYPE_CHECKING:
_stored: bool
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
self._stored = False
def __repr__(self):
def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
def __del__(self) -> None:
@ -391,17 +422,17 @@ class User(BaseUser, discord.abc.Messageable):
pass
@classmethod
def _copy(cls, user):
def _copy(cls: Type[U], user: U) -> U:
self = super()._copy(user)
self._stored = False
return self
async def _get_channel(self):
async def _get_channel(self) -> DMChannel:
ch = await self.create_dm()
return ch
@property
def dm_channel(self):
def dm_channel(self) -> Optional[DMChannel]:
"""Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists.
If this returns ``None``, you can create a DM channel by calling the
@ -410,7 +441,7 @@ class User(BaseUser, discord.abc.Messageable):
return self._state._get_private_channel_by_user(self.id)
@property
def mutual_guilds(self):
def mutual_guilds(self) -> List[Guild]:
"""List[:class:`Guild`]: The guilds that the user shares with the client.
.. note::
@ -421,7 +452,7 @@ class User(BaseUser, discord.abc.Messageable):
"""
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
async def create_dm(self):
async def create_dm(self) -> DMChannel:
"""|coro|
Creates a :class:`DMChannel` with this user.
@ -439,5 +470,5 @@ class User(BaseUser, discord.abc.Messageable):
return found
state = self._state
data = await state.http.start_private_message(self.id)
data: DMChannelPayload = await state.http.start_private_message(self.id)
return state.add_dm_channel(data)

Loading…
Cancel
Save