Browse Source

Typehint Member and various typing fixes

pull/7139/head
Rapptz 4 years ago
parent
commit
1aeec34f84
  1. 15
      discord/activity.py
  2. 2
      discord/guild.py
  3. 246
      discord/member.py
  4. 7
      discord/message.py
  5. 6
      discord/types/activity.py
  6. 15
      discord/types/member.py
  7. 4
      discord/types/message.py

15
discord/activity.py

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, overload
from .asset import Asset
from .enums import ActivityType, try_enum
@ -92,6 +92,7 @@ t.ActivityFlags = {
if TYPE_CHECKING:
from .types.activity import (
Activity as ActivityPayload,
ActivityTimestamps,
ActivityParty,
ActivityAssets,
@ -801,7 +802,17 @@ class CustomActivity(BaseActivity):
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'
def create_activity(data: Optional[Dict[str, Any]]) -> Optional[Union[Activity, Game, CustomActivity, Streaming, Spotify]]:
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload
def create_activity(data: ActivityPayload) -> ActivityTypes:
...
@overload
def create_activity(data: None) -> None:
...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
if not data:
return None

2
discord/guild.py

@ -454,7 +454,7 @@ class Guild(Hashable):
user_id = int(presence['user']['id'])
member = self.get_member(user_id)
if member is not None:
member._presence_update(presence, empty_tuple)
member._presence_update(presence, empty_tuple) # type: ignore
if 'channels' in data:
channels = data['channels']

246
discord/member.py

@ -29,14 +29,14 @@ import inspect
import itertools
import sys
from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
import discord.abc
from . import utils
from .utils import MISSING
from .user import BaseUser, User
from .activity import create_activity
from .activity import create_activity, ActivityTypes
from .permissions import Permissions
from .enums import Status, try_enum
from .colour import Colour
@ -49,10 +49,23 @@ __all__ = (
if TYPE_CHECKING:
from .channel import VoiceChannel, StageChannel
from .guild import Guild
from .types.activity import PartialPresenceUpdate
from .types.member import (
GatewayMember as GatewayMemberPayload,
Member as MemberPayload,
UserWithMember as UserWithMemberPayload,
)
from .types.user import User as UserPayload
from .abc import Snowflake
from .state import ConnectionState
from .message import Message
from .role import Role
from .types.voice import VoiceState as VoiceStatePayload
VocalGuildChannel = Union[VoiceChannel, StageChannel]
class VoiceState:
"""Represents a Discord user's voice state.
@ -96,38 +109,49 @@ class VoiceState:
is not currently in a voice channel.
"""
__slots__ = ('session_id', 'deaf', 'mute', 'self_mute',
'self_stream', 'self_video', 'self_deaf', 'afk', 'channel',
'requested_to_speak_at', 'suppress')
def __init__(self, *, data, channel=None):
self.session_id = data.get('session_id')
__slots__ = (
'session_id',
'deaf',
'mute',
'self_mute',
'self_stream',
'self_video',
'self_deaf',
'afk',
'channel',
'requested_to_speak_at',
'suppress',
)
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
self.session_id: str = data.get('session_id')
self._update(data, channel)
def _update(self, data, channel):
self.self_mute = data.get('self_mute', False)
self.self_deaf = data.get('self_deaf', False)
self.self_stream = data.get('self_stream', False)
self.self_video = data.get('self_video', False)
self.afk = data.get('suppress', False)
self.mute = data.get('mute', False)
self.deaf = data.get('deaf', False)
self.suppress = data.get('suppress', False)
self.requested_to_speak_at = utils.parse_time(data.get('request_to_speak_timestamp'))
self.channel = channel
def __repr__(self):
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False)
self.self_video: bool = data.get('self_video', False)
self.afk: bool = data.get('suppress', False)
self.mute: bool = data.get('mute', False)
self.deaf: bool = data.get('deaf', False)
self.suppress: bool = data.get('suppress', False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp'))
self.channel: Optional[VocalGuildChannel] = channel
def __repr__(self) -> str:
attrs = [
('self_mute', self.self_mute),
('self_deaf', self.self_deaf),
('self_stream', self.self_stream),
('suppress', self.suppress),
('requested_to_speak_at', self.requested_to_speak_at),
('channel', self.channel)
('channel', self.channel),
]
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>'
def flatten_user(cls):
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods
@ -151,9 +175,12 @@ def flatten_user(cls):
def generate_function(x):
# We want sphinx to properly show coroutine functions as coroutines
if inspect.iscoroutinefunction(value):
async def general(self, *args, **kwargs):
async def general(self, *args, **kwargs): # type: ignore
return await getattr(self._user, x)(*args, **kwargs)
else:
def general(self, *args, **kwargs):
return getattr(self._user, x)(*args, **kwargs)
@ -166,8 +193,12 @@ def flatten_user(cls):
return cls
_BaseUser = discord.abc.User
M = TypeVar('M', bound='Member')
@flatten_user
class Member(discord.abc.Messageable, _BaseUser):
"""Represents a Discord member to a :class:`Guild`.
@ -221,66 +252,90 @@ class Member(discord.abc.Messageable, _BaseUser):
Nitro boost on the guild, if available. This could be ``None``.
"""
__slots__ = ('_roles', 'joined_at', 'premium_since', '_client_status',
'activities', 'guild', 'pending', 'nick', '_user', '_state')
def __init__(self, *, data, guild, state):
self._state = state
self._user = state.store_user(data['user'])
self.guild = guild
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
self._update_roles(data)
self._client_status = {
None: 'offline'
}
self.activities = []
self.nick = data.get('nick', None)
self.pending = data.get('pending', False)
def __str__(self):
__slots__ = (
'_roles',
'joined_at',
'premium_since',
'activities',
'guild',
'pending',
'nick',
'_client_status',
'_user',
'_state',
)
if TYPE_CHECKING:
name: str
id: int
discriminator: str
bot: bool
system: bool
created_at: datetime.datetime
default_avatar = User.default_avatar
avatar = User.avatar
dm_channel = User.dm_channel
create_dm = User.create_dm
mutual_guilds = User.mutual_guilds
public_flags = User.public_flags
def __init__(self, *, data: GatewayMemberPayload, guild: Guild, state: ConnectionState):
self._state: ConnectionState = state
self._user: User = state.store_user(data['user'])
self.guild: Guild = guild
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at'))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since'))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles']))
self._client_status: Dict[Optional[str], str] = {None: 'offline'}
self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get('nick', None)
self.pending: bool = data.get('pending', False)
def __str__(self) -> str:
return str(self._user)
def __repr__(self):
return f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}' \
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
def __repr__(self) -> str:
return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return isinstance(other, _BaseUser) and other.id == self.id
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
def __hash__(self):
def __hash__(self) -> int:
return hash(self._user)
@classmethod
def _from_message(cls, *, message, data):
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M:
author = message.author
data['user'] = author._to_minimal_user_json()
return cls(data=data, guild=message.guild, state=message._state)
data['user'] = author._to_minimal_user_json() # type: ignore
return cls(data=data, guild=message.guild, state=message._state) # type: ignore
def _update_from_message(self, data):
def _update_from_message(self, data: MemberPayload) -> None:
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
self._update_roles(data)
self._roles = utils.SnowflakeList(map(int, data['roles']))
self.nick = data.get('nick', None)
self.pending = data.get('pending', False)
@classmethod
def _try_upgrade(cls, *, data, guild, state):
def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]:
# A User object with a 'member' key
try:
member_data = data.pop('member')
except KeyError:
return state.store_user(data)
else:
member_data['user'] = data
return cls(data=member_data, guild=guild, state=state)
member_data['user'] = data # type: ignore
return cls(data=member_data, guild=guild, state=state) # type: ignore
@classmethod
def _copy(cls, member):
self = cls.__new__(cls) # to bypass __init__
def _copy(cls: Type[M], member: M) -> M:
self: M = cls.__new__(cls) # to bypass __init__
self._roles = utils.SnowflakeList(member._roles, is_sorted=True)
self.joined_at = member.joined_at
@ -301,10 +356,7 @@ class Member(discord.abc.Messageable, _BaseUser):
ch = await self.create_dm()
return ch
def _update_roles(self, data):
self._roles = utils.SnowflakeList(map(int, data['roles']))
def _update(self, data):
def _update(self, data: MemberPayload) -> None:
# the nickname change is optional,
# if it isn't in the payload then it didn't change
try:
@ -318,21 +370,20 @@ class Member(discord.abc.Messageable, _BaseUser):
pass
self.premium_since = utils.parse_time(data.get('premium_since'))
self._update_roles(data)
self._roles = utils.SnowflakeList(map(int, data['roles']))
def _presence_update(self, data, user):
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data['activities']))
self._client_status = {
sys.intern(key): sys.intern(value)
for key, value in data.get('client_status', {}).items()
sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore
}
self._client_status[None] = sys.intern(data['status'])
if len(user) > 1:
return self._update_inner_user(user)
return False
return None
def _update_inner_user(self, user):
def _update_inner_user(self, user: UserPayload) -> Optional[Tuple[User, User]]:
u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags)
# These keys seem to always be available
@ -344,12 +395,12 @@ class Member(discord.abc.Messageable, _BaseUser):
return to_return, u
@property
def status(self):
def status(self) -> Status:
""":class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead."""
return try_enum(Status, self._client_status[None])
@property
def raw_status(self):
def raw_status(self) -> str:
""":class:`str`: The member's overall status as a string value.
.. versionadded:: 1.5
@ -357,31 +408,31 @@ class Member(discord.abc.Messageable, _BaseUser):
return self._client_status[None]
@status.setter
def status(self, value):
def status(self, value: Status) -> None:
# internal use only
self._client_status[None] = str(value)
@property
def mobile_status(self):
def mobile_status(self) -> Status:
""":class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.get('mobile', 'offline'))
@property
def desktop_status(self):
def desktop_status(self) -> Status:
""":class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.get('desktop', 'offline'))
@property
def web_status(self):
def web_status(self) -> Status:
""":class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.get('web', 'offline'))
def is_on_mobile(self):
def is_on_mobile(self) -> bool:
""":class:`bool`: A helper function that determines if a member is active on a mobile device."""
return 'mobile' in self._client_status
@property
def colour(self):
def colour(self) -> Colour:
""":class:`Colour`: A property that returns a colour denoting the rendered colour
for the member. If the default colour is the one rendered then an instance
of :meth:`Colour.default` is returned.
@ -389,7 +440,7 @@ class Member(discord.abc.Messageable, _BaseUser):
There is an alias for this named :attr:`color`.
"""
roles = self.roles[1:] # remove @everyone
roles = self.roles[1:] # remove @everyone
# highest order of the colour is the one that gets rendered.
# if the highest is the default colour then the next one with a colour
@ -400,7 +451,7 @@ class Member(discord.abc.Messageable, _BaseUser):
return Colour.default()
@property
def color(self):
def color(self) -> Colour:
""":class:`Colour`: A property that returns a color denoting the rendered color for
the member. If the default color is the one rendered then an instance of :meth:`Colour.default`
is returned.
@ -410,7 +461,7 @@ class Member(discord.abc.Messageable, _BaseUser):
return self.colour
@property
def roles(self):
def roles(self) -> List[Role]:
"""List[:class:`Role`]: A :class:`list` of :class:`Role` that the member belongs to. Note
that the first element of this list is always the default '@everyone'
role.
@ -428,14 +479,14 @@ class Member(discord.abc.Messageable, _BaseUser):
return result
@property
def mention(self):
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the member."""
if self.nick:
return f'<@!{self._user.id}>'
return f'<@{self._user.id}>'
@property
def display_name(self):
def display_name(self) -> str:
""":class:`str`: Returns the user's display name.
For regular users this is just their username, but
@ -445,8 +496,8 @@ class Member(discord.abc.Messageable, _BaseUser):
return self.nick or self.name
@property
def activity(self):
"""Union[:class:`BaseActivity`, :class:`Spotify`]: Returns the primary
def activity(self) -> Optional[ActivityTypes]:
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
activity the user is currently doing. Could be ``None`` if no activity is being done.
.. note::
@ -462,7 +513,7 @@ class Member(discord.abc.Messageable, _BaseUser):
if self.activities:
return self.activities[0]
def mentioned_in(self, message):
def mentioned_in(self, message: Message) -> bool:
"""Checks if the member is mentioned in the specified message.
Parameters
@ -484,7 +535,7 @@ class Member(discord.abc.Messageable, _BaseUser):
return any(self._roles.has(role.id) for role in message.role_mentions)
@property
def top_role(self):
def top_role(self) -> Role:
""":class:`Role`: Returns the member's highest role.
This is useful for figuring where a member stands in the role
@ -497,7 +548,7 @@ class Member(discord.abc.Messageable, _BaseUser):
return max(guild.get_role(rid) or guild.default_role for rid in self._roles)
@property
def guild_permissions(self):
def guild_permissions(self) -> Permissions:
""":class:`Permissions`: Returns the member's guild permissions.
This only takes into consideration the guild permissions
@ -522,29 +573,21 @@ class Member(discord.abc.Messageable, _BaseUser):
return base
@property
def voice(self):
def voice(self) -> Optional[VoiceState]:
"""Optional[:class:`VoiceState`]: Returns the member's current voice state."""
return self.guild._voice_state_for(self._user.id)
@overload
async def ban(
self,
*,
reason: Optional[str] = ...,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = ...,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1,
reason: Optional[str] = None,
) -> None:
...
@overload
async def ban(self) -> None:
...
async def ban(self, **kwargs):
"""|coro|
Bans this member. Equivalent to :meth:`Guild.ban`.
"""
await self.guild.ban(self, **kwargs)
await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days)
async def unban(self, *, reason: Optional[str] = None) -> None:
"""|coro|
@ -667,8 +710,7 @@ class Member(discord.abc.Messageable, _BaseUser):
if payload:
await http.edit_member(guild_id, self.id, reason=reason, **payload)
async def request_to_speak(self):
async def request_to_speak(self) -> None:
"""|coro|
Request to speak in the connected channel.
@ -723,7 +765,7 @@ class Member(discord.abc.Messageable, _BaseUser):
"""
await self.edit(voice_channel=channel, reason=reason)
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True):
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
r"""|coro|
Gives the member a number of :class:`Role`\s.
@ -792,7 +834,7 @@ class Member(discord.abc.Messageable, _BaseUser):
"""
if not atomic:
new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone
new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone
for role in roles:
try:
new_roles.remove(Object(id=role.id))
@ -807,7 +849,7 @@ class Member(discord.abc.Messageable, _BaseUser):
for role in roles:
await req(guild_id, user_id, role.id, reason=reason)
def get_role(self, role_id: int) -> Optional[discord.Role]:
def get_role(self, role_id: int) -> Optional[Role]:
"""Returns a role with the given ID from roles which the member has.
.. versionadded:: 2.0

7
discord/message.py

@ -60,7 +60,10 @@ if TYPE_CHECKING:
from .types.components import Component as ComponentPayload
from .types.threads import ThreadArchiveDuration
from .types.member import Member as MemberPayload
from .types.member import (
Member as MemberPayload,
UserWithMember as UserWithMemberPayload,
)
from .types.user import User as UserPayload
from .types.embed import Embed as EmbedPayload
from .abc import Snowflake
@ -839,7 +842,7 @@ class Message(Hashable):
# TODO: consider adding to cache here
self.author = Member._from_message(message=self, data=member)
def _handle_mentions(self, mentions: List[UserPayload]) -> None:
def _handle_mentions(self, mentions: List[UserWithMemberPayload]) -> None:
self.mentions = r = []
guild = self.guild
state = self._state

6
discord/types/activity.py

@ -41,9 +41,9 @@ class PartialPresenceUpdate(TypedDict):
class ClientStatus(TypedDict, total=False):
desktop: bool
mobile: bool
web: bool
desktop: str
mobile: str
web: str
class ActivityTimestamps(TypedDict, total=False):

15
discord/types/member.py

@ -44,3 +44,18 @@ class Member(PartialMember, total=False):
premium_since: str
pending: bool
permissions: str
class _OptionalGatewayMember(PartialMember, total=False):
nick: str
premium_since: str
pending: bool
permissions: str
class GatewayMember(_OptionalGatewayMember):
user: User
class UserWithMember(User, total=False):
member: _OptionalGatewayMember

4
discord/types/message.py

@ -26,7 +26,7 @@ from __future__ import annotations
from typing import List, Literal, Optional, TypedDict, Union
from .snowflake import Snowflake, SnowflakeList
from .member import Member
from .member import Member, UserWithMember
from .user import User
from .emoji import PartialEmoji
from .embed import Embed
@ -135,7 +135,7 @@ class Message(_MessageOptional):
edited_timestamp: Optional[str]
tts: bool
mention_everyone: bool
mentions: List[User]
mentions: List[UserWithMember]
mention_roles: SnowflakeList
attachments: List[Attachment]
embeds: List[Embed]

Loading…
Cancel
Save