From 3cf000d46716e5ed0f2340896e1924ed1dd67cc5 Mon Sep 17 00:00:00 2001 From: Nadir Chowdhury Date: Tue, 22 Feb 2022 03:09:40 +0000 Subject: [PATCH] Type up gateway payloads --- discord/member.py | 3 +- discord/message.py | 3 +- discord/raw_models.py | 17 +- discord/state.py | 150 ++++++++-------- discord/types/activity.py | 4 +- discord/types/appinfo.py | 5 + discord/types/gateway.py | 329 +++++++++++++++++++++++++++++++++++- discord/types/member.py | 6 +- discord/types/message.py | 8 + discord/types/raw_models.py | 87 ---------- 10 files changed, 434 insertions(+), 178 deletions(-) delete mode 100644 discord/types/raw_models.py diff --git a/discord/member.py b/discord/member.py index 4d05888ce..21676fc8a 100644 --- a/discord/member.py +++ b/discord/member.py @@ -59,6 +59,7 @@ if TYPE_CHECKING: Member as MemberPayload, UserWithMember as UserWithMemberPayload, ) + from .types.gateway import GuildMemberUpdateEvent from .types.user import User as UserPayload from .abc import Snowflake from .state import ConnectionState @@ -372,7 +373,7 @@ class Member(discord.abc.Messageable, _UserTag): ch = await self.create_dm() return ch - def _update(self, data: MemberPayload) -> None: + def _update(self, data: GuildMemberUpdateEvent) -> None: # the nickname change is optional, # if it isn't in the payload then it didn't change try: diff --git a/discord/message.py b/discord/message.py index 02a4efee2..f1f257be1 100644 --- a/discord/message.py +++ b/discord/message.py @@ -82,6 +82,7 @@ if TYPE_CHECKING: ) from .types.user import User as UserPayload from .types.embed import Embed as EmbedPayload + from .types.gateway import MessageReactionRemoveEvent from .abc import Snowflake from .abc import GuildChannel, PartialMessageableChannel, MessageableChannel from .components import Component @@ -762,7 +763,7 @@ class Message(Hashable): return reaction - def _remove_reaction(self, data: ReactionPayload, emoji: EmojiInputType, user_id: int) -> Reaction: + def _remove_reaction(self, data: MessageReactionRemoveEvent, emoji: EmojiInputType, user_id: int) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) if reaction is None: diff --git a/discord/raw_models.py b/discord/raw_models.py index b8a5acc3e..a8cd8370f 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -24,22 +24,25 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Set, List +from typing import TYPE_CHECKING, Optional, Set, List, Tuple, Union if TYPE_CHECKING: - from .types.raw_models import ( + from .types.gateway import ( MessageDeleteEvent, - BulkMessageDeleteEvent, - ReactionActionEvent, + MessageDeleteBulkEvent as BulkMessageDeleteEvent, + MessageReactionAddEvent, + MessageReactionRemoveEvent, + MessageReactionRemoveAllEvent as ReactionClearEvent, + MessageReactionRemoveEmojiEvent as ReactionClearEmojiEvent, MessageUpdateEvent, - ReactionClearEvent, - ReactionClearEmojiEvent, IntegrationDeleteEvent, ) from .message import Message from .partial_emoji import PartialEmoji from .member import Member + ReactionActionEvent = Union[MessageReactionAddEvent, MessageReactionRemoveEvent] + __all__ = ( 'RawMessageDeleteEvent', @@ -53,6 +56,8 @@ __all__ = ( class _RawReprMixin: + __slots__: Tuple[str, ...] = () + def __repr__(self) -> str: value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__) return f'<{self.__class__.__name__} {value}>' diff --git a/discord/state.py b/discord/state.py index 39a676773..17a722bdd 100644 --- a/discord/state.py +++ b/discord/state.py @@ -69,13 +69,15 @@ if TYPE_CHECKING: from .client import Client from .gateway import DiscordWebSocket + from .types.snowflake import Snowflake from .types.activity import Activity as ActivityPayload from .types.channel import DMChannel as DMChannelPayload from .types.user import User as UserPayload from .types.emoji import Emoji as EmojiPayload from .types.sticker import GuildSticker as GuildStickerPayload from .types.guild import Guild as GuildPayload - from .types.message import Message as MessagePayload + from .types.message import Message as MessagePayload, PartialMessage as PartialMessagePayload + from .types import gateway as gw T = TypeVar('T') CS = TypeVar('CS', bound='ConnectionState') @@ -447,7 +449,7 @@ class ConnectionState: return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) def _get_guild_channel( - self, data: MessagePayload, guild_id: Optional[int] = None + self, data: PartialMessagePayload, guild_id: Optional[int] = None ) -> Tuple[Union[Channel, Thread], Optional[Guild]]: channel_id = int(data['channel_id']) try: @@ -532,7 +534,7 @@ class ConnectionState: finally: self._ready_task = None - def parse_ready(self, data) -> None: + def parse_ready(self, data: gw.ReadyEvent) -> None: if self._ready_task is not None: self._ready_task.cancel() @@ -552,15 +554,15 @@ class ConnectionState: self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore for guild_data in data['guilds']: - self._add_guild_from_data(guild_data) + self._add_guild_from_data(guild_data) # type: ignore self.dispatch('connect') self._ready_task = asyncio.create_task(self._delay_ready()) - def parse_resumed(self, data) -> None: + def parse_resumed(self, data: gw.ResumedEvent) -> None: self.dispatch('resumed') - def parse_message_create(self, data) -> None: + def parse_message_create(self, data: gw.MessageCreateEvent) -> None: channel, _ = self._get_guild_channel(data) # channel would be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore @@ -571,7 +573,7 @@ class ConnectionState: if channel and channel.__class__ in (TextChannel, Thread): channel.last_message_id = message.id # type: ignore - def parse_message_delete(self, data) -> None: + def parse_message_delete(self, data: gw.MessageDeleteEvent) -> None: raw = RawMessageDeleteEvent(data) found = self._get_message(raw.message_id) raw.cached_message = found @@ -580,7 +582,7 @@ class ConnectionState: self.dispatch('message_delete', found) self._messages.remove(found) - def parse_message_delete_bulk(self, data) -> None: + def parse_message_delete_bulk(self, data: gw.MessageDeleteBulkEvent) -> None: raw = RawBulkMessageDeleteEvent(data) if self._messages: found_messages = [message for message in self._messages if message.id in raw.message_ids] @@ -594,7 +596,7 @@ class ConnectionState: # self._messages won't be None here self._messages.remove(msg) # type: ignore - def parse_message_update(self, data) -> None: + def parse_message_update(self, data: gw.MessageUpdateEvent) -> None: raw = RawMessageUpdateEvent(data) message = self._get_message(raw.message_id) if message is not None: @@ -612,10 +614,9 @@ class ConnectionState: if 'components' in data and self._view_store.is_message_tracked(raw.message_id): self._view_store.update_from_message(raw.message_id, data['components']) - def parse_message_reaction_add(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) + def parse_message_reaction_add(self, data: gw.MessageReactionAddEvent) -> None: + emoji = PartialEmoji.from_dict(data['emoji']) + emoji._state = self raw = RawReactionActionEvent(data, emoji, 'REACTION_ADD') member_data = data.get('member') @@ -639,7 +640,7 @@ class ConnectionState: if user: self.dispatch('reaction_add', reaction, user) - def parse_message_reaction_remove_all(self, data) -> None: + def parse_message_reaction_remove_all(self, data: gw.MessageReactionRemoveAllEvent) -> None: raw = RawReactionClearEvent(data) self.dispatch('raw_reaction_clear', raw) @@ -649,10 +650,9 @@ class ConnectionState: message.reactions.clear() self.dispatch('reaction_clear', message, old_reactions) - def parse_message_reaction_remove(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) + def parse_message_reaction_remove(self, data: gw.MessageReactionRemoveEvent) -> None: + emoji = PartialEmoji.from_dict(data['emoji']) + emoji._state = self raw = RawReactionActionEvent(data, emoji, 'REACTION_REMOVE') self.dispatch('raw_reaction_remove', raw) @@ -668,10 +668,9 @@ class ConnectionState: if user: self.dispatch('reaction_remove', reaction, user) - def parse_message_reaction_remove_emoji(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) + def parse_message_reaction_remove_emoji(self, data: gw.MessageReactionRemoveEmojiEvent) -> None: + emoji = PartialEmoji.from_dict(data['emoji']) + emoji._state = self raw = RawReactionClearEmojiEvent(data, emoji) self.dispatch('raw_reaction_clear_emoji', raw) @@ -685,7 +684,7 @@ class ConnectionState: if reaction: self.dispatch('reaction_clear_emoji', reaction) - def parse_interaction_create(self, data) -> None: + def parse_interaction_create(self, data: gw.InteractionCreateEvent) -> None: interaction = Interaction(data=data, state=self) if data['type'] == 3: # interaction component custom_id = interaction.data['custom_id'] # type: ignore @@ -697,7 +696,7 @@ class ConnectionState: self._view_store.dispatch_modal(custom_id, interaction, components) # type: ignore self.dispatch('interaction', interaction) - def parse_presence_update(self, data) -> None: + def parse_presence_update(self, data: gw.PresenceUpdateEvent) -> None: guild_id = utils._get_as_snowflake(data, 'guild_id') # guild_id won't be None here guild = self._get_guild(guild_id) @@ -719,19 +718,19 @@ class ConnectionState: self.dispatch('presence_update', old_member, member) - def parse_user_update(self, data): + def parse_user_update(self, data: gw.UserUpdateEvent): if self.user: self.user._update(data) - def parse_invite_create(self, data) -> None: + def parse_invite_create(self, data: gw.InviteCreateEvent) -> None: invite = Invite.from_gateway(state=self, data=data) self.dispatch('invite_create', invite) - def parse_invite_delete(self, data) -> None: + def parse_invite_delete(self, data: gw.InviteDeleteEvent) -> None: invite = Invite.from_gateway(state=self, data=data) self.dispatch('invite_delete', invite) - def parse_channel_delete(self, data) -> None: + def parse_channel_delete(self, data: gw.ChannelDeleteEvent) -> None: guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) channel_id = int(data['id']) if guild is not None: @@ -740,7 +739,7 @@ class ConnectionState: guild._remove_channel(channel) self.dispatch('guild_channel_delete', channel) - def parse_channel_update(self, data) -> None: + def parse_channel_update(self, data: gw.ChannelUpdateEvent) -> None: channel_type = try_enum(ChannelType, data.get('type')) channel_id = int(data['id']) if channel_type is ChannelType.group: @@ -757,14 +756,14 @@ class ConnectionState: channel = guild.get_channel(channel_id) if channel is not None: old_channel = copy.copy(channel) - channel._update(guild, data) + channel._update(guild, data) # type: ignore - the data payload varies based on the channel type. self.dispatch('guild_channel_update', old_channel, channel) else: _log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) else: _log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_channel_create(self, data) -> None: + def parse_channel_create(self, data: gw.ChannelCreateEvent) -> None: factory, ch_type = _channel_factory(data['type']) if factory is None: _log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) @@ -781,7 +780,7 @@ class ConnectionState: _log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) return - def parse_channel_pins_update(self, data) -> None: + def parse_channel_pins_update(self, data: gw.ChannelPinsUpdateEvent) -> None: channel_id = int(data['channel_id']) try: guild = self._get_guild(int(data['guild_id'])) @@ -795,14 +794,14 @@ class ConnectionState: _log.debug('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) return - last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None + last_pin = utils.parse_time(data.get('last_pin_timestamp')) if guild is None: self.dispatch('private_channel_pins_update', channel, last_pin) else: self.dispatch('guild_channel_pins_update', channel, last_pin) - def parse_thread_create(self, data) -> None: + def parse_thread_create(self, data: gw.ThreadCreateEvent) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -815,7 +814,7 @@ class ConnectionState: if not has_thread: self.dispatch('thread_join', thread) - def parse_thread_update(self, data) -> None: + def parse_thread_update(self, data: gw.ThreadUpdateEvent) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) if guild is None: @@ -833,7 +832,7 @@ class ConnectionState: guild._add_thread(thread) self.dispatch('thread_join', thread) - def parse_thread_delete(self, data) -> None: + def parse_thread_delete(self, data: gw.ThreadDeleteEvent) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) if guild is None: @@ -843,10 +842,10 @@ class ConnectionState: thread_id = int(data['id']) thread = guild.get_thread(thread_id) if thread is not None: - guild._remove_thread(thread) # type: ignore + guild._remove_thread(thread) self.dispatch('thread_delete', thread) - def parse_thread_list_sync(self, data) -> None: + def parse_thread_list_sync(self, data: gw.ThreadListSyncEvent) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -854,7 +853,7 @@ class ConnectionState: return try: - channel_ids = set(data['channel_ids']) + channel_ids = {int(i) for i in data['channel_ids']} except KeyError: # If not provided, then the entire guild is being synced # So all previous thread data should be overwritten @@ -882,7 +881,7 @@ class ConnectionState: for thread in previous_threads.values(): self.dispatch('thread_remove', thread) - def parse_thread_member_update(self, data) -> None: + def parse_thread_member_update(self, data: gw.ThreadMemberUpdate) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -898,7 +897,7 @@ class ConnectionState: member = ThreadMember(thread, data) thread.me = member - def parse_thread_members_update(self, data) -> None: + def parse_thread_members_update(self, data: gw.ThreadMembersUpdate) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -930,7 +929,7 @@ class ConnectionState: else: self.dispatch('thread_remove', thread) - def parse_guild_member_add(self, data) -> None: + def parse_guild_member_add(self, data: gw.GuildMemberAddEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: _log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -947,7 +946,7 @@ class ConnectionState: self.dispatch('member_join', member) - def parse_guild_member_remove(self, data) -> None: + def parse_guild_member_remove(self, data: gw.GuildMemberRemoveEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: try: @@ -963,7 +962,7 @@ class ConnectionState: else: _log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_guild_member_update(self, data) -> None: + def parse_guild_member_update(self, data: gw.GuildMemberUpdateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) user = data['user'] user_id = int(user['id']) @@ -982,7 +981,7 @@ class ConnectionState: self.dispatch('member_update', old_member, member) else: if self.member_cache_flags.joined: - member = Member(data=data, guild=guild, state=self) + member = Member(data=data, guild=guild, state=self) # type: ignore - the data is not complete, contains a delta of values # Force an update on the inner user if necessary user_update = member._update_inner_user(user) @@ -992,7 +991,7 @@ class ConnectionState: guild._add_member(member) _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) - def parse_guild_emojis_update(self, data) -> None: + def parse_guild_emojis_update(self, data: gw.GuildEmojisUpdateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: _log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -1056,7 +1055,7 @@ class ConnectionState: else: self.dispatch('guild_join', guild) - def parse_guild_create(self, data) -> None: + def parse_guild_create(self, data: gw.GuildCreateEvent) -> None: unavailable = data.get('unavailable') if unavailable is True: # joined a guild with unavailable == True so.. @@ -1084,7 +1083,7 @@ class ConnectionState: else: self.dispatch('guild_join', guild) - def parse_guild_update(self, data) -> None: + def parse_guild_update(self, data: gw.GuildUpdateEvent) -> None: guild = self._get_guild(int(data['id'])) if guild is not None: old_guild = copy.copy(guild) @@ -1093,7 +1092,7 @@ class ConnectionState: else: _log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) - def parse_guild_delete(self, data) -> None: + def parse_guild_delete(self, data: gw.GuildDeleteEvent) -> None: guild = self._get_guild(int(data['id'])) if guild is None: _log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) @@ -1115,7 +1114,7 @@ class ConnectionState: self._remove_guild(guild) self.dispatch('guild_remove', guild) - def parse_guild_ban_add(self, data) -> None: + def parse_guild_ban_add(self, data: gw.GuildBanAddEvent) -> None: # we make the assumption that GUILD_BAN_ADD is done # before GUILD_MEMBER_REMOVE is called # hence we don't remove it from cache or do anything @@ -1131,13 +1130,13 @@ class ConnectionState: member = guild.get_member(user.id) or user self.dispatch('member_ban', guild, member) - def parse_guild_ban_remove(self, data) -> None: + def parse_guild_ban_remove(self, data: gw.GuildBanRemoveEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None and 'user' in data: user = self.store_user(data['user']) self.dispatch('member_unban', guild, user) - def parse_guild_role_create(self, data) -> None: + def parse_guild_role_create(self, data: gw.GuildRoleCreateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: _log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -1148,7 +1147,7 @@ class ConnectionState: guild._add_role(role) self.dispatch('guild_role_create', role) - def parse_guild_role_delete(self, data) -> None: + def parse_guild_role_delete(self, data: gw.GuildRoleDeleteEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: role_id = int(data['role_id']) @@ -1161,7 +1160,7 @@ class ConnectionState: else: _log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_guild_role_update(self, data) -> None: + def parse_guild_role_update(self, data: gw.GuildRoleUpdateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: role_data = data['role'] @@ -1174,7 +1173,7 @@ class ConnectionState: else: _log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_guild_members_chunk(self, data) -> None: + def parse_guild_members_chunk(self, data: gw.GuildMembersChunkEvent) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) presences = data.get('presences', []) @@ -1184,7 +1183,7 @@ class ConnectionState: _log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) if presences: - member_dict = {str(member.id): member for member in members} + member_dict: Dict[Snowflake, Member] = {str(member.id): member for member in members} for presence in presences: user = presence['user'] member_id = user['id'] @@ -1195,14 +1194,14 @@ class ConnectionState: complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count') self.process_chunk_requests(guild_id, data.get('nonce'), members, complete) - def parse_guild_integrations_update(self, data) -> None: + def parse_guild_integrations_update(self, data: gw.GuildIntegrationsUpdateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: self.dispatch('guild_integrations_update', guild) else: _log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_integration_create(self, data) -> None: + def parse_integration_create(self, data: gw.IntegrationCreateEvent) -> None: guild_id = int(data.pop('guild_id')) guild = self._get_guild(guild_id) if guild is not None: @@ -1212,7 +1211,7 @@ class ConnectionState: else: _log.debug('INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_integration_update(self, data) -> None: + def parse_integration_update(self, data: gw.IntegrationUpdateEvent) -> None: guild_id = int(data.pop('guild_id')) guild = self._get_guild(guild_id) if guild is not None: @@ -1222,7 +1221,7 @@ class ConnectionState: else: _log.debug('INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_integration_delete(self, data) -> None: + def parse_integration_delete(self, data: gw.IntegrationDeleteEvent) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) if guild is not None: @@ -1231,7 +1230,7 @@ class ConnectionState: else: _log.debug('INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_webhooks_update(self, data) -> None: + def parse_webhooks_update(self, data: gw.WebhooksUpdateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: _log.debug('WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding', data['guild_id']) @@ -1243,7 +1242,7 @@ class ConnectionState: else: _log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) - def parse_stage_instance_create(self, data) -> None: + def parse_stage_instance_create(self, data: gw.StageInstanceCreateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: stage_instance = StageInstance(guild=guild, state=self, data=data) @@ -1252,7 +1251,7 @@ class ConnectionState: else: _log.debug('STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_stage_instance_update(self, data) -> None: + def parse_stage_instance_update(self, data: gw.StageInstanceUpdateEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: stage_instance = guild._stage_instances.get(int(data['id'])) @@ -1265,7 +1264,7 @@ class ConnectionState: else: _log.debug('STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_stage_instance_delete(self, data) -> None: + def parse_stage_instance_delete(self, data: gw.StageInstanceDeleteEvent) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: try: @@ -1277,7 +1276,7 @@ class ConnectionState: else: _log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_voice_state_update(self, data) -> None: + def parse_voice_state_update(self, data: gw.VoiceStateUpdateEvent) -> None: guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) channel_id = utils._get_as_snowflake(data, 'channel_id') flags = self.member_cache_flags @@ -1304,18 +1303,15 @@ class ConnectionState: else: _log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) - def parse_voice_server_update(self, data) -> None: - try: - key_id = int(data['guild_id']) - except KeyError: - key_id = int(data['channel_id']) + def parse_voice_server_update(self, data: gw.VoiceServerUpdateEvent) -> None: + key_id = int(data['guild_id']) vc = self._get_voice_client(key_id) if vc is not None: coro = vc.on_voice_server_update(data) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) - def parse_typing_start(self, data) -> None: + def parse_typing_start(self, data: gw.TypingStartEvent) -> None: channel, guild = self._get_guild_channel(data) if channel is not None: member = None @@ -1336,7 +1332,7 @@ class ConnectionState: member = utils.find(lambda x: x.id == user_id, channel.recipients) if member is not None: - timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) + timestamp = datetime.datetime.fromtimestamp(data['timestamp'], tz=datetime.timezone.utc) self.dispatch('typing', channel, member, timestamp) def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: @@ -1482,7 +1478,7 @@ class AutoShardedConnectionState(ConnectionState): self.call_handlers('ready') self.dispatch('ready') - def parse_ready(self, data) -> None: + def parse_ready(self, data: gw.ReadyEvent) -> None: if not hasattr(self, '_ready_state'): self._ready_state = asyncio.Queue() @@ -1500,17 +1496,17 @@ class AutoShardedConnectionState(ConnectionState): self.application_flags = ApplicationFlags._from_value(application['flags']) for guild_data in data['guilds']: - self._add_guild_from_data(guild_data) + self._add_guild_from_data(guild_data) # type: ignore - _add_guild_from_data requires a complete Guild payload if self._messages: self._update_message_references() self.dispatch('connect') - self.dispatch('shard_connect', data['__shard_id__']) + self.dispatch('shard_connect', data['__shard_id__']) # type: ignore if self._ready_task is None: self._ready_task = asyncio.create_task(self._delay_ready()) - def parse_resumed(self, data) -> None: + def parse_resumed(self, data: gw.ResumedEvent) -> None: self.dispatch('resumed') - self.dispatch('shard_resumed', data['__shard_id__']) + self.dispatch('shard_resumed', data['__shard_id__']) # type: ignore diff --git a/discord/types/activity.py b/discord/types/activity.py index 9d46001e1..282656ce0 100644 --- a/discord/types/activity.py +++ b/discord/types/activity.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import List, Literal, Optional, TypedDict -from .user import PartialUser +from .user import User from .snowflake import Snowflake @@ -33,7 +33,7 @@ StatusType = Literal['idle', 'dnd', 'online', 'offline'] class PartialPresenceUpdate(TypedDict): - user: PartialUser + user: User guild_id: Snowflake status: StatusType activities: List[Activity] diff --git a/discord/types/appinfo.py b/discord/types/appinfo.py index e691e812c..282100a24 100644 --- a/discord/types/appinfo.py +++ b/discord/types/appinfo.py @@ -70,3 +70,8 @@ class _PartialAppInfoOptional(TypedDict, total=False): class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo): pass + + +class GatewayAppInfo(TypedDict): + id: Snowflake + flags: int diff --git a/discord/types/gateway.py b/discord/types/gateway.py index bcf3e4673..e0a07059a 100644 --- a/discord/types/gateway.py +++ b/discord/types/gateway.py @@ -22,7 +22,25 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import TypedDict +from typing import List, Literal, Optional, TypedDict + + +from .activity import PartialPresenceUpdate +from .voice import GuildVoiceState +from .integration import BaseIntegration, IntegrationApplication +from .role import Role +from .channel import Channel, ChannelType, StageInstance +from .interactions import Interaction +from .invite import InviteTargetType +from .emoji import PartialEmoji +from .member import Member, MemberWithUser +from .snowflake import Snowflake +from .message import Message +from .sticker import GuildSticker +from .appinfo import GatewayAppInfo, PartialAppInfo +from .guild import Guild, UnavailableGuild +from .user import User +from .threads import Thread, ThreadMember class SessionStartLimit(TypedDict): @@ -39,3 +57,312 @@ class Gateway(TypedDict): class GatewayBot(Gateway): shards: int session_start_limit: SessionStartLimit + + +class ShardInfo(TypedDict): + shard_id: int + shard_count: int + + +class ReadyEvent(TypedDict): + v: int + user: User + guilds: List[UnavailableGuild] + session_id: str + shard: ShardInfo + application: GatewayAppInfo + + +ResumedEvent = Literal[None] + +MessageCreateEvent = Message + + +class _MessageDeleteEventOptional(TypedDict, total=False): + guild_id: Snowflake + + +class MessageDeleteEvent(_MessageDeleteEventOptional): + id: Snowflake + channel_id: Snowflake + + +class _MessageDeleteBulkEventOptional(TypedDict, total=False): + guild_id: Snowflake + + +class MessageDeleteBulkEvent(_MessageDeleteBulkEventOptional): + ids: List[Snowflake] + channel_id: Snowflake + + +class MessageUpdateEvent(Message): + channel_id: Snowflake + + +class _MessageReactionAddEventOptional(TypedDict, total=False): + member: MemberWithUser + guild_id: Snowflake + + +class MessageReactionAddEvent(_MessageReactionAddEventOptional): + user_id: Snowflake + channel_id: Snowflake + message_id: Snowflake + emoji: PartialEmoji + + +class _MessageReactionRemoveEventOptional(TypedDict, total=False): + guild_id: Snowflake + + +class MessageReactionRemoveEvent(_MessageReactionRemoveEventOptional): + user_id: Snowflake + channel_id: Snowflake + message_id: Snowflake + emoji: PartialEmoji + + +class _MessageReactionRemoveAllEventOptional(TypedDict, total=False): + guild_id: Snowflake + + +class MessageReactionRemoveAllEvent(_MessageReactionRemoveAllEventOptional): + message_id: Snowflake + channel_id: Snowflake + + +class _MessageReactionRemoveEmojiEventOptional(TypedDict, total=False): + guild_id: Snowflake + + +class MessageReactionRemoveEmojiEvent(_MessageReactionRemoveEmojiEventOptional): + emoji: PartialEmoji + message_id: Snowflake + channel_id: Snowflake + + +InteractionCreateEvent = Interaction + + +PresenceUpdateEvent = PartialPresenceUpdate + + +UserUpdateEvent = User + + +class _InviteCreateEventOptional(TypedDict, total=False): + guild_id: Snowflake + inviter: User + target_type: InviteTargetType + target_user: User + target_application: PartialAppInfo + + +class InviteCreateEvent(_InviteCreateEventOptional): + channel_id: Snowflake + code: str + created_at: str + max_age: int + max_uses: int + temporary: bool + uses: Literal[0] + + +class _InviteDeleteEventOptional(TypedDict, total=False): + guild_id: Snowflake + + +class InviteDeleteEvent(_InviteDeleteEventOptional): + channel_id: Snowflake + code: str + + +class _ChannelEvent(TypedDict): + id: Snowflake + type: ChannelType + + +ChannelCreateEvent = ChannelUpdateEvent = ChannelDeleteEvent = _ChannelEvent + + +class _ChannelPinsUpdateEventOptional(TypedDict, total=False): + guild_id: Snowflake + last_pin_timestamp: Optional[str] + + +class ChannelPinsUpdateEvent(_ChannelPinsUpdateEventOptional): + channel_id: Snowflake + + +class _ThreadCreateEventOptional(TypedDict, total=False): + newly_created: bool + members: List[ThreadMember] + + +class ThreadCreateEvent(Thread, _ThreadCreateEventOptional): + ... + + +ThreadUpdateEvent = Thread + + +class ThreadDeleteEvent(TypedDict): + id: Snowflake + guild_id: Snowflake + parent_id: Snowflake + type: ChannelType + + +class _ThreadListSyncEventOptional(TypedDict, total=False): + channel_ids: List[Snowflake] + + +class ThreadListSyncEvent(_ThreadListSyncEventOptional): + guild_id: Snowflake + threads: List[Thread] + members: List[ThreadMember] + + +class ThreadMemberUpdate(ThreadMember): + guild_id: Snowflake + + +class _ThreadMembersUpdateOptional(TypedDict, total=False): + added_members: List[ThreadMember] + removed_member_ids: List[Snowflake] + + +class ThreadMembersUpdate(_ThreadMembersUpdateOptional): + id: Snowflake + guild_id: Snowflake + member_count: int + + +class GuildMemberAddEvent(MemberWithUser): + guild_id: Snowflake + + +class GuildMemberRemoveEvent(TypedDict): + guild_id: Snowflake + user: User + + +class _GuildMemberUpdateEventOptional(TypedDict, total=False): + nick: str + premium_since: Optional[str] + deaf: bool + mute: bool + pending: bool + communication_disabled_until: str + + +class GuildMemberUpdateEvent(_GuildMemberUpdateEventOptional): + guild_id: Snowflake + roles: List[Snowflake] + user: User + avatar: Optional[str] + joined_at: Optional[str] + + +class GuildEmojisUpdateEvent(TypedDict): + guild_id: Snowflake + emojis: List[PartialEmoji] + + +class GuildStickersUpdateEvent(TypedDict): + guild_id: Snowflake + stickers: List[GuildSticker] + + +GuildCreateEvent = GuildUpdateEvent = Guild +GuildDeleteEvent = UnavailableGuild + + +class _GuildBanEvent(TypedDict): + guild_id: Snowflake + user: User + + +GuildBanAddEvent = GuildBanRemoveEvent = _GuildBanEvent + + +class _GuildRoleEvent(TypedDict): + guild_id: Snowflake + role: Role + + +class GuildRoleDeleteEvent(TypedDict): + guild_id: Snowflake + role_id: Snowflake + + +GuildRoleCreateEvent = GuildRoleUpdateEvent = _GuildRoleEvent + + +class _GuildMembersChunkEventOptional(TypedDict, total=False): + not_found: List[Snowflake] + presences: List[PresenceUpdateEvent] + nonce: str + + +class GuildMembersChunkEvent(_GuildMembersChunkEventOptional): + guild_id: Snowflake + members: List[MemberWithUser] + chunk_index: int + chunk_count: int + + +class GuildIntegrationsUpdateEvent(TypedDict): + guild_id: Snowflake + + +class _IntegrationEventOptional(BaseIntegration, total=False): + role_id: Optional[Snowflake] + enable_emoticons: bool + subscriber_count: int + revoked: bool + application: IntegrationApplication + + +class _IntegrationEvent(_IntegrationEventOptional): + guild_id: Snowflake + + +IntegrationCreateEvent = IntegrationUpdateEvent = _IntegrationEvent + + +class _IntegrationDeleteEventOptional(TypedDict, total=False): + application_id: Snowflake + + +class IntegrationDeleteEvent(_IntegrationDeleteEventOptional): + id: Snowflake + guild_id: Snowflake + + +class WebhooksUpdateEvent(TypedDict): + guild_id: Snowflake + channel_id: Snowflake + + +StageInstanceCreateEvent = StageInstanceUpdateEvent = StageInstanceDeleteEvent = StageInstance + +VoiceStateUpdateEvent = GuildVoiceState + + +class VoiceServerUpdateEvent(TypedDict): + token: str + guild_id: Snowflake + endpoint: Optional[str] + + +class _TypingStartEventOptional(TypedDict, total=False): + guild_id: Snowflake + member: MemberWithUser + + +class TypingStartEvent(_TypingStartEventOptional): + channel_id: Snowflake + user_id: Snowflake + timestamp: int diff --git a/discord/types/member.py b/discord/types/member.py index c7bf5ac47..f4748e452 100644 --- a/discord/types/member.py +++ b/discord/types/member.py @@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from typing import TypedDict +from typing import Optional, TypedDict from .snowflake import SnowflakeList from .user import User @@ -42,7 +42,7 @@ class Member(PartialMember, total=False): avatar: str user: User nick: str - premium_since: str + premium_since: Optional[str] pending: bool permissions: str communication_disabled_until: str @@ -51,7 +51,7 @@ class Member(PartialMember, total=False): class _OptionalMemberWithUser(PartialMember, total=False): avatar: str nick: str - premium_since: str + premium_since: Optional[str] pending: bool permissions: str communication_disabled_until: str diff --git a/discord/types/message.py b/discord/types/message.py index 0e1bff19d..151b8add8 100644 --- a/discord/types/message.py +++ b/discord/types/message.py @@ -36,6 +36,14 @@ from .interactions import MessageInteraction from .sticker import StickerItem +class _PartialMessageOptional(TypedDict, total=False): + guild_id: Snowflake + + +class PartialMessage(_PartialMessageOptional): + channel_id: Snowflake + + class ChannelMention(TypedDict): id: Snowflake guild_id: Snowflake diff --git a/discord/types/raw_models.py b/discord/types/raw_models.py deleted file mode 100644 index 3c45b299c..000000000 --- a/discord/types/raw_models.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -The MIT License (MIT) - -Copyright (c) 2015-present Rapptz - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation -the rights to use, copy, modify, merge, publish, distribute, sublicense, -and/or sell copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. -""" - -from typing import TypedDict, List -from .snowflake import Snowflake -from .member import Member -from .emoji import PartialEmoji - - -class _MessageEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class MessageDeleteEvent(_MessageEventOptional): - id: Snowflake - channel_id: Snowflake - - -class BulkMessageDeleteEvent(_MessageEventOptional): - ids: List[Snowflake] - channel_id: Snowflake - - -class _ReactionActionEventOptional(TypedDict, total=False): - guild_id: Snowflake - member: Member - - -class MessageUpdateEvent(_MessageEventOptional): - id: Snowflake - channel_id: Snowflake - - -class ReactionActionEvent(_ReactionActionEventOptional): - user_id: Snowflake - channel_id: Snowflake - message_id: Snowflake - emoji: PartialEmoji - - -class _ReactionClearEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class ReactionClearEvent(_ReactionClearEventOptional): - channel_id: Snowflake - message_id: Snowflake - - -class _ReactionClearEmojiEventOptional(TypedDict, total=False): - guild_id: Snowflake - - -class ReactionClearEmojiEvent(_ReactionClearEmojiEventOptional): - channel_id: int - message_id: int - emoji: PartialEmoji - - -class _IntegrationDeleteEventOptional(TypedDict, total=False): - application_id: Snowflake - - -class IntegrationDeleteEvent(_IntegrationDeleteEventOptional): - id: Snowflake - guild_id: Snowflake