diff --git a/discord/activity.py b/discord/activity.py index f1fe39f98..d0a1af9a7 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, overload from .asset import Asset from .enums import ActivityType, try_enum @@ -92,6 +92,7 @@ t.ActivityFlags = { if TYPE_CHECKING: from .types.activity import ( + Activity as ActivityPayload, ActivityTimestamps, ActivityParty, ActivityAssets, @@ -801,7 +802,17 @@ class CustomActivity(BaseActivity): return f'' -def create_activity(data: Optional[Dict[str, Any]]) -> Optional[Union[Activity, Game, CustomActivity, Streaming, Spotify]]: +ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] + +@overload +def create_activity(data: ActivityPayload) -> ActivityTypes: + ... + +@overload +def create_activity(data: None) -> None: + ... + +def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: if not data: return None diff --git a/discord/guild.py b/discord/guild.py index b43b12df2..8ec73543b 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -454,7 +454,7 @@ class Guild(Hashable): user_id = int(presence['user']['id']) member = self.get_member(user_id) if member is not None: - member._presence_update(presence, empty_tuple) + member._presence_update(presence, empty_tuple) # type: ignore if 'channels' in data: channels = data['channels'] diff --git a/discord/member.py b/discord/member.py index 5bc59c8d4..86e08ef66 100644 --- a/discord/member.py +++ b/discord/member.py @@ -29,14 +29,14 @@ import inspect import itertools import sys from operator import attrgetter -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Union, overload +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload import discord.abc from . import utils from .utils import MISSING from .user import BaseUser, User -from .activity import create_activity +from .activity import create_activity, ActivityTypes from .permissions import Permissions from .enums import Status, try_enum from .colour import Colour @@ -49,10 +49,23 @@ __all__ = ( if TYPE_CHECKING: from .channel import VoiceChannel, StageChannel + from .guild import Guild + from .types.activity import PartialPresenceUpdate + from .types.member import ( + GatewayMember as GatewayMemberPayload, + Member as MemberPayload, + UserWithMember as UserWithMemberPayload, + ) + from .types.user import User as UserPayload from .abc import Snowflake + from .state import ConnectionState + from .message import Message + from .role import Role + from .types.voice import VoiceState as VoiceStatePayload VocalGuildChannel = Union[VoiceChannel, StageChannel] + class VoiceState: """Represents a Discord user's voice state. @@ -96,38 +109,49 @@ class VoiceState: is not currently in a voice channel. """ - __slots__ = ('session_id', 'deaf', 'mute', 'self_mute', - 'self_stream', 'self_video', 'self_deaf', 'afk', 'channel', - 'requested_to_speak_at', 'suppress') - - def __init__(self, *, data, channel=None): - self.session_id = data.get('session_id') + __slots__ = ( + 'session_id', + 'deaf', + 'mute', + 'self_mute', + 'self_stream', + 'self_video', + 'self_deaf', + 'afk', + 'channel', + 'requested_to_speak_at', + 'suppress', + ) + + def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): + self.session_id: str = data.get('session_id') self._update(data, channel) - def _update(self, data, channel): - self.self_mute = data.get('self_mute', False) - self.self_deaf = data.get('self_deaf', False) - self.self_stream = data.get('self_stream', False) - self.self_video = data.get('self_video', False) - self.afk = data.get('suppress', False) - self.mute = data.get('mute', False) - self.deaf = data.get('deaf', False) - self.suppress = data.get('suppress', False) - self.requested_to_speak_at = utils.parse_time(data.get('request_to_speak_timestamp')) - self.channel = channel - - def __repr__(self): + def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]): + self.self_mute: bool = data.get('self_mute', False) + self.self_deaf: bool = data.get('self_deaf', False) + self.self_stream: bool = data.get('self_stream', False) + self.self_video: bool = data.get('self_video', False) + self.afk: bool = data.get('suppress', False) + self.mute: bool = data.get('mute', False) + self.deaf: bool = data.get('deaf', False) + self.suppress: bool = data.get('suppress', False) + self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp')) + self.channel: Optional[VocalGuildChannel] = channel + + def __repr__(self) -> str: attrs = [ ('self_mute', self.self_mute), ('self_deaf', self.self_deaf), ('self_stream', self.self_stream), ('suppress', self.suppress), ('requested_to_speak_at', self.requested_to_speak_at), - ('channel', self.channel) + ('channel', self.channel), ] inner = ' '.join('%s=%r' % t for t in attrs) return f'<{self.__class__.__name__} {inner}>' + def flatten_user(cls): for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): # ignore private/special methods @@ -151,9 +175,12 @@ def flatten_user(cls): def generate_function(x): # We want sphinx to properly show coroutine functions as coroutines if inspect.iscoroutinefunction(value): - async def general(self, *args, **kwargs): + + async def general(self, *args, **kwargs): # type: ignore return await getattr(self._user, x)(*args, **kwargs) + else: + def general(self, *args, **kwargs): return getattr(self._user, x)(*args, **kwargs) @@ -166,8 +193,12 @@ def flatten_user(cls): return cls + _BaseUser = discord.abc.User +M = TypeVar('M', bound='Member') + + @flatten_user class Member(discord.abc.Messageable, _BaseUser): """Represents a Discord member to a :class:`Guild`. @@ -221,66 +252,90 @@ class Member(discord.abc.Messageable, _BaseUser): Nitro boost on the guild, if available. This could be ``None``. """ - __slots__ = ('_roles', 'joined_at', 'premium_since', '_client_status', - 'activities', 'guild', 'pending', 'nick', '_user', '_state') - - def __init__(self, *, data, guild, state): - self._state = state - self._user = state.store_user(data['user']) - self.guild = guild - self.joined_at = utils.parse_time(data.get('joined_at')) - self.premium_since = utils.parse_time(data.get('premium_since')) - self._update_roles(data) - self._client_status = { - None: 'offline' - } - self.activities = [] - self.nick = data.get('nick', None) - self.pending = data.get('pending', False) - - def __str__(self): + __slots__ = ( + '_roles', + 'joined_at', + 'premium_since', + 'activities', + 'guild', + 'pending', + 'nick', + '_client_status', + '_user', + '_state', + ) + + if TYPE_CHECKING: + name: str + id: int + discriminator: str + bot: bool + system: bool + created_at: datetime.datetime + default_avatar = User.default_avatar + avatar = User.avatar + dm_channel = User.dm_channel + create_dm = User.create_dm + mutual_guilds = User.mutual_guilds + public_flags = User.public_flags + + def __init__(self, *, data: GatewayMemberPayload, guild: Guild, state: ConnectionState): + self._state: ConnectionState = state + self._user: User = state.store_user(data['user']) + self.guild: Guild = guild + self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at')) + self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since')) + self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles'])) + self._client_status: Dict[Optional[str], str] = {None: 'offline'} + self.activities: Tuple[ActivityTypes, ...] = tuple() + self.nick: Optional[str] = data.get('nick', None) + self.pending: bool = data.get('pending', False) + + def __str__(self) -> str: return str(self._user) - def __repr__(self): - return f'' + def __repr__(self) -> str: + return ( + f'' + ) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, _BaseUser) and other.id == self.id - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self._user) @classmethod - def _from_message(cls, *, message, data): + def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M: author = message.author - data['user'] = author._to_minimal_user_json() - return cls(data=data, guild=message.guild, state=message._state) + data['user'] = author._to_minimal_user_json() # type: ignore + return cls(data=data, guild=message.guild, state=message._state) # type: ignore - def _update_from_message(self, data): + def _update_from_message(self, data: MemberPayload) -> None: self.joined_at = utils.parse_time(data.get('joined_at')) self.premium_since = utils.parse_time(data.get('premium_since')) - self._update_roles(data) + self._roles = utils.SnowflakeList(map(int, data['roles'])) self.nick = data.get('nick', None) self.pending = data.get('pending', False) @classmethod - def _try_upgrade(cls, *, data, guild, state): + def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]: # A User object with a 'member' key try: member_data = data.pop('member') except KeyError: return state.store_user(data) else: - member_data['user'] = data - return cls(data=member_data, guild=guild, state=state) + member_data['user'] = data # type: ignore + return cls(data=member_data, guild=guild, state=state) # type: ignore @classmethod - def _copy(cls, member): - self = cls.__new__(cls) # to bypass __init__ + def _copy(cls: Type[M], member: M) -> M: + self: M = cls.__new__(cls) # to bypass __init__ self._roles = utils.SnowflakeList(member._roles, is_sorted=True) self.joined_at = member.joined_at @@ -301,10 +356,7 @@ class Member(discord.abc.Messageable, _BaseUser): ch = await self.create_dm() return ch - def _update_roles(self, data): - self._roles = utils.SnowflakeList(map(int, data['roles'])) - - def _update(self, data): + def _update(self, data: MemberPayload) -> None: # the nickname change is optional, # if it isn't in the payload then it didn't change try: @@ -318,21 +370,20 @@ class Member(discord.abc.Messageable, _BaseUser): pass self.premium_since = utils.parse_time(data.get('premium_since')) - self._update_roles(data) + self._roles = utils.SnowflakeList(map(int, data['roles'])) - def _presence_update(self, data, user): + def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]: self.activities = tuple(map(create_activity, data['activities'])) self._client_status = { - sys.intern(key): sys.intern(value) - for key, value in data.get('client_status', {}).items() + sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore } self._client_status[None] = sys.intern(data['status']) if len(user) > 1: return self._update_inner_user(user) - return False + return None - def _update_inner_user(self, user): + def _update_inner_user(self, user: UserPayload) -> Optional[Tuple[User, User]]: u = self._user original = (u.name, u._avatar, u.discriminator, u._public_flags) # These keys seem to always be available @@ -344,12 +395,12 @@ class Member(discord.abc.Messageable, _BaseUser): return to_return, u @property - def status(self): + def status(self) -> Status: """:class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead.""" return try_enum(Status, self._client_status[None]) @property - def raw_status(self): + def raw_status(self) -> str: """:class:`str`: The member's overall status as a string value. .. versionadded:: 1.5 @@ -357,31 +408,31 @@ class Member(discord.abc.Messageable, _BaseUser): return self._client_status[None] @status.setter - def status(self, value): + def status(self, value: Status) -> None: # internal use only self._client_status[None] = str(value) @property - def mobile_status(self): + def mobile_status(self) -> Status: """:class:`Status`: The member's status on a mobile device, if applicable.""" return try_enum(Status, self._client_status.get('mobile', 'offline')) @property - def desktop_status(self): + def desktop_status(self) -> Status: """:class:`Status`: The member's status on the desktop client, if applicable.""" return try_enum(Status, self._client_status.get('desktop', 'offline')) @property - def web_status(self): + def web_status(self) -> Status: """:class:`Status`: The member's status on the web client, if applicable.""" return try_enum(Status, self._client_status.get('web', 'offline')) - def is_on_mobile(self): + def is_on_mobile(self) -> bool: """:class:`bool`: A helper function that determines if a member is active on a mobile device.""" return 'mobile' in self._client_status @property - def colour(self): + def colour(self) -> Colour: """:class:`Colour`: A property that returns a colour denoting the rendered colour for the member. If the default colour is the one rendered then an instance of :meth:`Colour.default` is returned. @@ -389,7 +440,7 @@ class Member(discord.abc.Messageable, _BaseUser): There is an alias for this named :attr:`color`. """ - roles = self.roles[1:] # remove @everyone + roles = self.roles[1:] # remove @everyone # highest order of the colour is the one that gets rendered. # if the highest is the default colour then the next one with a colour @@ -400,7 +451,7 @@ class Member(discord.abc.Messageable, _BaseUser): return Colour.default() @property - def color(self): + def color(self) -> Colour: """:class:`Colour`: A property that returns a color denoting the rendered color for the member. If the default color is the one rendered then an instance of :meth:`Colour.default` is returned. @@ -410,7 +461,7 @@ class Member(discord.abc.Messageable, _BaseUser): return self.colour @property - def roles(self): + def roles(self) -> List[Role]: """List[:class:`Role`]: A :class:`list` of :class:`Role` that the member belongs to. Note that the first element of this list is always the default '@everyone' role. @@ -428,14 +479,14 @@ class Member(discord.abc.Messageable, _BaseUser): return result @property - def mention(self): + def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention the member.""" if self.nick: return f'<@!{self._user.id}>' return f'<@{self._user.id}>' @property - def display_name(self): + def display_name(self) -> str: """:class:`str`: Returns the user's display name. For regular users this is just their username, but @@ -445,8 +496,8 @@ class Member(discord.abc.Messageable, _BaseUser): return self.nick or self.name @property - def activity(self): - """Union[:class:`BaseActivity`, :class:`Spotify`]: Returns the primary + def activity(self) -> Optional[ActivityTypes]: + """Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary activity the user is currently doing. Could be ``None`` if no activity is being done. .. note:: @@ -462,7 +513,7 @@ class Member(discord.abc.Messageable, _BaseUser): if self.activities: return self.activities[0] - def mentioned_in(self, message): + def mentioned_in(self, message: Message) -> bool: """Checks if the member is mentioned in the specified message. Parameters @@ -484,7 +535,7 @@ class Member(discord.abc.Messageable, _BaseUser): return any(self._roles.has(role.id) for role in message.role_mentions) @property - def top_role(self): + def top_role(self) -> Role: """:class:`Role`: Returns the member's highest role. This is useful for figuring where a member stands in the role @@ -497,7 +548,7 @@ class Member(discord.abc.Messageable, _BaseUser): return max(guild.get_role(rid) or guild.default_role for rid in self._roles) @property - def guild_permissions(self): + def guild_permissions(self) -> Permissions: """:class:`Permissions`: Returns the member's guild permissions. This only takes into consideration the guild permissions @@ -522,29 +573,21 @@ class Member(discord.abc.Messageable, _BaseUser): return base @property - def voice(self): + def voice(self) -> Optional[VoiceState]: """Optional[:class:`VoiceState`]: Returns the member's current voice state.""" return self.guild._voice_state_for(self._user.id) - @overload async def ban( self, *, - reason: Optional[str] = ..., - delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = ..., + delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1, + reason: Optional[str] = None, ) -> None: - ... - - @overload - async def ban(self) -> None: - ... - - async def ban(self, **kwargs): """|coro| Bans this member. Equivalent to :meth:`Guild.ban`. """ - await self.guild.ban(self, **kwargs) + await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days) async def unban(self, *, reason: Optional[str] = None) -> None: """|coro| @@ -667,8 +710,7 @@ class Member(discord.abc.Messageable, _BaseUser): if payload: await http.edit_member(guild_id, self.id, reason=reason, **payload) - - async def request_to_speak(self): + async def request_to_speak(self) -> None: """|coro| Request to speak in the connected channel. @@ -723,7 +765,7 @@ class Member(discord.abc.Messageable, _BaseUser): """ await self.edit(voice_channel=channel, reason=reason) - async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True): + async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: r"""|coro| Gives the member a number of :class:`Role`\s. @@ -792,7 +834,7 @@ class Member(discord.abc.Messageable, _BaseUser): """ if not atomic: - new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone + new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone for role in roles: try: new_roles.remove(Object(id=role.id)) @@ -807,7 +849,7 @@ class Member(discord.abc.Messageable, _BaseUser): for role in roles: await req(guild_id, user_id, role.id, reason=reason) - def get_role(self, role_id: int) -> Optional[discord.Role]: + def get_role(self, role_id: int) -> Optional[Role]: """Returns a role with the given ID from roles which the member has. .. versionadded:: 2.0 diff --git a/discord/message.py b/discord/message.py index e641028b6..21635e1db 100644 --- a/discord/message.py +++ b/discord/message.py @@ -60,7 +60,10 @@ if TYPE_CHECKING: from .types.components import Component as ComponentPayload from .types.threads import ThreadArchiveDuration - from .types.member import Member as MemberPayload + from .types.member import ( + Member as MemberPayload, + UserWithMember as UserWithMemberPayload, + ) from .types.user import User as UserPayload from .types.embed import Embed as EmbedPayload from .abc import Snowflake @@ -839,7 +842,7 @@ class Message(Hashable): # TODO: consider adding to cache here self.author = Member._from_message(message=self, data=member) - def _handle_mentions(self, mentions: List[UserPayload]) -> None: + def _handle_mentions(self, mentions: List[UserWithMemberPayload]) -> None: self.mentions = r = [] guild = self.guild state = self._state diff --git a/discord/types/activity.py b/discord/types/activity.py index 9c583d3a8..9d46001e1 100644 --- a/discord/types/activity.py +++ b/discord/types/activity.py @@ -41,9 +41,9 @@ class PartialPresenceUpdate(TypedDict): class ClientStatus(TypedDict, total=False): - desktop: bool - mobile: bool - web: bool + desktop: str + mobile: str + web: str class ActivityTimestamps(TypedDict, total=False): diff --git a/discord/types/member.py b/discord/types/member.py index b80ab6130..d93005e8f 100644 --- a/discord/types/member.py +++ b/discord/types/member.py @@ -44,3 +44,18 @@ class Member(PartialMember, total=False): premium_since: str pending: bool permissions: str + + +class _OptionalGatewayMember(PartialMember, total=False): + nick: str + premium_since: str + pending: bool + permissions: str + + +class GatewayMember(_OptionalGatewayMember): + user: User + + +class UserWithMember(User, total=False): + member: _OptionalGatewayMember diff --git a/discord/types/message.py b/discord/types/message.py index 0b345f8b9..2dbdf9839 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -26,7 +26,7 @@ from __future__ import annotations from typing import List, Literal, Optional, TypedDict, Union from .snowflake import Snowflake, SnowflakeList -from .member import Member +from .member import Member, UserWithMember from .user import User from .emoji import PartialEmoji from .embed import Embed @@ -135,7 +135,7 @@ class Message(_MessageOptional): edited_timestamp: Optional[str] tts: bool mention_everyone: bool - mentions: List[User] + mentions: List[UserWithMember] mention_roles: SnowflakeList attachments: List[Attachment] embeds: List[Embed]