From 75923005358047bde36657ce93c5e9b183aa82fb Mon Sep 17 00:00:00 2001 From: Stocker <44980366+StockerMC@users.noreply.github.com> Date: Sat, 21 Aug 2021 14:39:02 -0400 Subject: [PATCH] Typehint state.py --- discord/state.py | 411 ++++++++++++++++++++++----------------- discord/types/appinfo.py | 1 + discord/types/voice.py | 4 +- 3 files changed, 237 insertions(+), 179 deletions(-) diff --git a/discord/state.py b/discord/state.py index b6f407d8d..680d112f7 100644 --- a/discord/state.py +++ b/discord/state.py @@ -30,9 +30,7 @@ import copy import datetime import itertools import logging -from typing import Dict, Optional, TYPE_CHECKING, Union -import weakref -import warnings +from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque import inspect import os @@ -56,26 +54,43 @@ from .object import Object from .invite import Invite from .integrations import _integration_factory from .interactions import Interaction -from .ui.view import ViewStore +from .ui.view import ViewStore, View from .stage_instance import StageInstance from .threads import Thread, ThreadMember from .sticker import GuildSticker if TYPE_CHECKING: + from .abc import PrivateChannel + from .message import MessageableChannel + from .guild import GuildChannel, VocalGuildChannel from .http import HTTPClient + from .voice_client import VoiceProtocol + from .client import Client + from .gateway import DiscordWebSocket + from .types.activity import Activity as ActivityPayload + from .types.channel import DMChannel as DMChannelPayload + from .types.user import User as UserPayload + from .types.emoji import Emoji as EmojiPayload + from .types.sticker import GuildSticker as GuildStickerPayload + from .types.guild import Guild as GuildPayload + from .types.message import Message as MessagePayload + + T = TypeVar('T') + CS = TypeVar('CS', bound='ConnectionState') + Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] class ChunkRequest: - def __init__(self, guild_id, loop, resolver, *, cache=True): - self.guild_id = guild_id - self.resolver = resolver - self.loop = loop - self.cache = cache - self.nonce = os.urandom(16).hex() - self.buffer = [] # List[Member] - self.waiters = [] - - def add_members(self, members): + def __init__(self, guild_id: int, loop: asyncio.AbstractEventLoop, resolver: Callable[[int], Any], *, cache: bool = True) -> None: + self.guild_id: int = guild_id + self.resolver: Callable[[int], Any] = resolver + self.loop: asyncio.AbstractEventLoop = loop + self.cache: bool = cache + self.nonce: str = os.urandom(16).hex() + self.buffer: List[Member] = [] + self.waiters: List[asyncio.Future[List[Member]]] = [] + + def add_members(self, members: List[Member]) -> None: self.buffer.extend(members) if self.cache: guild = self.resolver(self.guild_id) @@ -87,7 +102,7 @@ class ChunkRequest: if existing is None or existing.joined_at is None: guild._add_member(member) - async def wait(self): + async def wait(self) -> List[Member]: future = self.loop.create_future() self.waiters.append(future) try: @@ -95,35 +110,40 @@ class ChunkRequest: finally: self.waiters.remove(future) - def get_future(self): + def get_future(self) -> asyncio.Future[List[Member]]: future = self.loop.create_future() self.waiters.append(future) return future - def done(self): + def done(self) -> None: for future in self.waiters: if not future.done(): future.set_result(self.buffer) -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) -async def logging_coroutine(coroutine, *, info): +async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: try: await coroutine except Exception: log.exception('Exception occurred during %s', info) class ConnectionState: - def __init__(self, *, dispatch, handlers, hooks, http: HTTPClient, loop: asyncio.AbstractEventLoop, **options): + if TYPE_CHECKING: + _get_websocket: Callable[..., DiscordWebSocket] + _get_client: Callable[..., Client] + _parsers: Dict[str, Callable[[Dict[str, Any]], None]] + + def __init__(self, *, dispatch: Callable, handlers: Dict[str, Callable], hooks: Dict[str, Callable], http: HTTPClient, loop: asyncio.AbstractEventLoop, **options: Any) -> None: self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http self.max_messages: Optional[int] = options.get('max_messages', 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 - self.dispatch = dispatch - self.handlers = handlers - self.hooks = hooks + self.dispatch: Callable = dispatch + self.handlers: Dict[str, Callable] = handlers + self.hooks: Dict[str, Callable] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id') @@ -195,8 +215,8 @@ class ConnectionState: self.clear() - def clear(self): - self.user = None + def clear(self) -> None: + self.user: Optional[ClientUser] = None # Originally, this code used WeakValueDictionary to maintain references to the # global user mapping. @@ -210,19 +230,22 @@ class ConnectionState: # using __del__. Testing this for memory leaks led to no discernable leaks, # though more testing will have to be done. self._users: Dict[int, User] = {} - self._emojis = {} - self._stickers = {} - self._guilds = {} - self._view_store = ViewStore(self) - self._voice_clients = {} + self._emojis: Dict[int, Emoji] = {} + self._stickers: Dict[int, GuildSticker] = {} + self._guilds: Dict[int, Guild] = {} + self._view_store: ViewStore = ViewStore(self) + self._voice_clients: Dict[int, VoiceProtocol] = {} # LRU of max size 128 - self._private_channels = OrderedDict() + self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict() # extra dict to look up private channels by user id - self._private_channels_by_user = {} - self._messages = self.max_messages and deque(maxlen=self.max_messages) + self._private_channels_by_user: Dict[int, PrivateChannel] = {} + if self.max_messages is not None: + self._messages: Optional[Deque[Message]] = deque(maxlen=self.max_messages) + else: + self._messages: Optional[Deque[Message]] = None - def process_chunk_requests(self, guild_id, nonce, members, complete): + def process_chunk_requests(self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool) -> None: removed = [] for key, request in self._chunk_requests.items(): if request.guild_id == guild_id and request.nonce == nonce: @@ -234,7 +257,7 @@ class ConnectionState: for key in removed: del self._chunk_requests[key] - def call_handlers(self, key, *args, **kwargs): + def call_handlers(self, key: str, *args: Any, **kwargs: Any) -> None: try: func = self.handlers[key] except KeyError: @@ -242,7 +265,7 @@ class ConnectionState: else: func(*args, **kwargs) - async def call_hooks(self, key, *args, **kwargs): + async def call_hooks(self, key: str, *args: Any, **kwargs: Any) -> None: try: coro = self.hooks[key] except KeyError: @@ -251,34 +274,35 @@ class ConnectionState: await coro(*args, **kwargs) @property - def self_id(self): + def self_id(self) -> Optional[int]: u = self.user return u.id if u else None @property - def intents(self): + def intents(self) -> Intents: ret = Intents.none() ret.value = self._intents.value return ret @property - def voice_clients(self): + def voice_clients(self) -> List[VoiceProtocol]: return list(self._voice_clients.values()) - def _get_voice_client(self, guild_id): - return self._voice_clients.get(guild_id) + def _get_voice_client(self, guild_id: Optional[int]) -> Optional[VoiceProtocol]: + # the keys of self._voice_clients are ints + return self._voice_clients.get(guild_id) # type: ignore - def _add_voice_client(self, guild_id, voice): + def _add_voice_client(self, guild_id: int, voice: VoiceProtocol) -> None: self._voice_clients[guild_id] = voice - def _remove_voice_client(self, guild_id): + def _remove_voice_client(self, guild_id: int) -> None: self._voice_clients.pop(guild_id, None) - def _update_references(self, ws): + def _update_references(self, ws: DiscordWebSocket) -> None: for vc in self.voice_clients: - vc.main_ws = ws + vc.main_ws = ws # type: ignore - def store_user(self, data): + def store_user(self, data: UserPayload) -> User: user_id = int(data['id']) try: return self._users[user_id] @@ -289,49 +313,52 @@ class ConnectionState: user._stored = True return user - def deref_user(self, user_id): + def deref_user(self, user_id: int) -> None: self._users.pop(user_id, None) - def create_user(self, data): + def create_user(self, data: UserPayload) -> User: return User(state=self, data=data) - def deref_user_no_intents(self, user_id): + def deref_user_no_intents(self, user_id: int) -> None: return - def get_user(self, id): - return self._users.get(id) + def get_user(self, id: Optional[int]) -> Optional[User]: + # the keys of self._users are ints + return self._users.get(id) # type: ignore - def store_emoji(self, guild, data): - emoji_id = int(data['id']) + def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: + # the id will be present here + emoji_id = int(data['id']) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji - def store_sticker(self, guild, data): + def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: sticker_id = int(data['id']) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view, message_id=None): + def store_view(self, view: View, message_id: Optional[int] = None) -> None: self._view_store.add_view(view, message_id) - def prevent_view_updates_for(self, message_id): + def prevent_view_updates_for(self, message_id: int) -> Optional[View]: return self._view_store.remove_message_tracking(message_id) @property - def persistent_views(self): + def persistent_views(self) -> Sequence[View]: return self._view_store.persistent_views @property - def guilds(self): + def guilds(self) -> List[Guild]: return list(self._guilds.values()) - def _get_guild(self, guild_id): - return self._guilds.get(guild_id) + def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]: + # the keys of self._guilds are ints + return self._guilds.get(guild_id) # type: ignore - def _add_guild(self, guild): + def _add_guild(self, guild: Guild) -> None: self._guilds[guild.id] = guild - def _remove_guild(self, guild): + def _remove_guild(self, guild: Guild) -> None: self._guilds.pop(guild.id, None) for emoji in guild.emojis: @@ -343,36 +370,40 @@ class ConnectionState: del guild @property - def emojis(self): + def emojis(self) -> List[Emoji]: return list(self._emojis.values()) @property - def stickers(self): + def stickers(self) -> List[GuildSticker]: return list(self._stickers.values()) - def get_emoji(self, emoji_id): - return self._emojis.get(emoji_id) + def get_emoji(self, emoji_id: Optional[int]) -> Optional[Emoji]: + # the keys of self._emojis are ints + return self._emojis.get(emoji_id) # type: ignore - def get_sticker(self, sticker_id): - return self._stickers.get(sticker_id) + def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]: + # the keys of self._stickers are ints + return self._stickers.get(sticker_id) # type: ignore @property - def private_channels(self): + def private_channels(self) -> List[PrivateChannel]: return list(self._private_channels.values()) - def _get_private_channel(self, channel_id): + def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: try: - value = self._private_channels[channel_id] + # the keys of self._private_channels are ints + value = self._private_channels[channel_id] # type: ignore except KeyError: return None else: - self._private_channels.move_to_end(channel_id) + self._private_channels.move_to_end(channel_id) # type: ignore return value - def _get_private_channel_by_user(self, user_id): - return self._private_channels_by_user.get(user_id) + def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[PrivateChannel]: + # the keys of self._private_channels are ints + return self._private_channels_by_user.get(user_id) # type: ignore - def _add_private_channel(self, channel): + def _add_private_channel(self, channel: PrivateChannel) -> None: channel_id = channel.id self._private_channels[channel_id] = channel @@ -384,29 +415,32 @@ class ConnectionState: if isinstance(channel, DMChannel) and channel.recipient: self._private_channels_by_user[channel.recipient.id] = channel - def add_dm_channel(self, data): - channel = DMChannel(me=self.user, state=self, data=data) + def add_dm_channel(self, data: DMChannelPayload) -> DMChannel: + # self.user is *always* cached when this is called + channel = DMChannel(me=self.user, state=self, data=data) # type: ignore self._add_private_channel(channel) return channel - def _remove_private_channel(self, channel): + def _remove_private_channel(self, channel: PrivateChannel) -> None: self._private_channels.pop(channel.id, None) if isinstance(channel, DMChannel): - self._private_channels_by_user.pop(channel.recipient.id, None) + recipient = channel.recipient + if recipient is not None: + self._private_channels_by_user.pop(recipient.id, None) - def _get_message(self, msg_id): + def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None - def _add_guild_from_data(self, guild): - guild = Guild(data=guild, state=self) + def _add_guild_from_data(self, data: GuildPayload) -> Guild: + guild = Guild(data=data, state=self) self._add_guild(guild) return guild - def _guild_needs_chunking(self, guild): + def _guild_needs_chunking(self, guild: Guild) -> bool: # If presences are enabled then we get back the old guild.large behaviour return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) - def _get_guild_channel(self, data): + def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]: channel_id = int(data['channel_id']) try: guild = self._get_guild(int(data['guild_id'])) @@ -418,11 +452,11 @@ class ConnectionState: return channel or PartialMessageable(state=self, id=channel_id), guild - async def chunker(self, guild_id, query='', limit=0, presences=False, *, nonce=None): + async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None) -> None: ws = self._get_websocket(guild_id) # This is ignored upstream await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) - async def query_members(self, guild, query, limit, user_ids, cache, presences): + async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): guild_id = guild.id ws = self._get_websocket(guild_id) if ws is None: @@ -439,7 +473,7 @@ class ConnectionState: log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) raise - async def _delay_ready(self): + async def _delay_ready(self) -> None: try: states = [] while True: @@ -485,13 +519,13 @@ class ConnectionState: finally: self._ready_task = None - def parse_ready(self, data): + def parse_ready(self, data) -> None: if self._ready_task is not None: self._ready_task.cancel() self._ready_state = asyncio.Queue() self.clear() - self.user = user = ClientUser(state=self, data=data['user']) + self.user = ClientUser(state=self, data=data['user']) self.store_user(data['user']) if self.application_id is None: @@ -501,7 +535,8 @@ class ConnectionState: pass else: self.application_id = utils._get_as_snowflake(application, 'id') - self.application_flags = ApplicationFlags._from_value(application['flags']) + # flags will always be present here + self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore for guild_data in data['guilds']: self._add_guild_from_data(guild_data) @@ -509,19 +544,21 @@ class ConnectionState: self.dispatch('connect') self._ready_task = asyncio.create_task(self._delay_ready()) - def parse_resumed(self, data): + def parse_resumed(self, data) -> None: self.dispatch('resumed') - def parse_message_create(self, data): + def parse_message_create(self, data) -> None: channel, _ = self._get_guild_channel(data) - message = Message(channel=channel, data=data, state=self) + # channel would be the correct type here + message = Message(channel=channel, data=data, state=self) # type: ignore self.dispatch('message', message) if self._messages is not None: self._messages.append(message) + # we ensure that the channel is either a TextChannel or Thread if channel and channel.__class__ in (TextChannel, Thread): - channel.last_message_id = message.id + channel.last_message_id = message.id # type: ignore - def parse_message_delete(self, data): + def parse_message_delete(self, data) -> None: raw = RawMessageDeleteEvent(data) found = self._get_message(raw.message_id) raw.cached_message = found @@ -530,7 +567,7 @@ class ConnectionState: self.dispatch('message_delete', found) self._messages.remove(found) - def parse_message_delete_bulk(self, data): + def parse_message_delete_bulk(self, data) -> None: raw = RawBulkMessageDeleteEvent(data) if self._messages: found_messages = [message for message in self._messages if message.id in raw.message_ids] @@ -541,9 +578,10 @@ class ConnectionState: if found_messages: self.dispatch('bulk_message_delete', found_messages) for msg in found_messages: - self._messages.remove(msg) + # self._messages won't be None here + self._messages.remove(msg) # type: ignore - def parse_message_update(self, data): + def parse_message_update(self, data) -> None: raw = RawMessageUpdateEvent(data) message = self._get_message(raw.message_id) if message is not None: @@ -561,7 +599,7 @@ class ConnectionState: if 'components' in data and self._view_store.is_message_tracked(raw.message_id): self._view_store.update_from_message(raw.message_id, data['components']) - def parse_message_reaction_add(self, data): + def parse_message_reaction_add(self, data) -> None: emoji = data['emoji'] emoji_id = utils._get_as_snowflake(emoji, 'id') emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) @@ -570,7 +608,10 @@ class ConnectionState: member_data = data.get('member') if member_data: guild = self._get_guild(raw.guild_id) - raw.member = Member(data=member_data, guild=guild, state=self) + if guild is not None: + raw.member = Member(data=member_data, guild=guild, state=self) + else: + raw.member = None else: raw.member = None self.dispatch('raw_reaction_add', raw) @@ -585,7 +626,7 @@ class ConnectionState: if user: self.dispatch('reaction_add', reaction, user) - def parse_message_reaction_remove_all(self, data): + def parse_message_reaction_remove_all(self, data) -> None: raw = RawReactionClearEvent(data) self.dispatch('raw_reaction_clear', raw) @@ -595,7 +636,7 @@ class ConnectionState: message.reactions.clear() self.dispatch('reaction_clear', message, old_reactions) - def parse_message_reaction_remove(self, data): + def parse_message_reaction_remove(self, data) -> None: emoji = data['emoji'] emoji_id = utils._get_as_snowflake(emoji, 'id') emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) @@ -614,7 +655,7 @@ class ConnectionState: if user: self.dispatch('reaction_remove', reaction, user) - def parse_message_reaction_remove_emoji(self, data): + def parse_message_reaction_remove_emoji(self, data) -> None: emoji = data['emoji'] emoji_id = utils._get_as_snowflake(emoji, 'id') emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) @@ -631,7 +672,7 @@ class ConnectionState: if reaction: self.dispatch('reaction_clear_emoji', reaction) - def parse_interaction_create(self, data): + def parse_interaction_create(self, data) -> None: interaction = Interaction(data=data, state=self) if data['type'] == 3: # interaction component custom_id = interaction.data['custom_id'] # type: ignore @@ -640,8 +681,9 @@ class ConnectionState: self.dispatch('interaction', interaction) - def parse_presence_update(self, data): + def parse_presence_update(self, data) -> None: guild_id = utils._get_as_snowflake(data, 'guild_id') + # guild_id won't be None here guild = self._get_guild(guild_id) if guild is None: log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) @@ -661,21 +703,23 @@ class ConnectionState: self.dispatch('presence_update', old_member, member) - def parse_user_update(self, data): - self.user._update(data) - ref = self._users.get(self.user.id) + def parse_user_update(self, data) -> None: + # self.user is *always* cached when this is called + user: ClientUser = self.user # type: ignore + user._update(data) + ref = self._users.get(user.id) if ref: ref._update(data) - def parse_invite_create(self, data): + def parse_invite_create(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) self.dispatch('invite_create', invite) - def parse_invite_delete(self, data): + def parse_invite_delete(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) self.dispatch('invite_delete', invite) - def parse_channel_delete(self, data): + def parse_channel_delete(self, data) -> None: guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) channel_id = int(data['id']) if guild is not None: @@ -684,13 +728,14 @@ class ConnectionState: guild._remove_channel(channel) self.dispatch('guild_channel_delete', channel) - def parse_channel_update(self, data): + def parse_channel_update(self, data) -> None: channel_type = try_enum(ChannelType, data.get('type')) channel_id = int(data['id']) if channel_type is ChannelType.group: channel = self._get_private_channel(channel_id) old_channel = copy.copy(channel) - channel._update_group(data) + # the channel is a GroupChannel + channel._update_group(data) # type: ignore self.dispatch('private_channel_update', old_channel, channel) return @@ -707,7 +752,7 @@ class ConnectionState: else: log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_channel_create(self, data): + def parse_channel_create(self, data) -> None: factory, ch_type = _channel_factory(data['type']) if factory is None: log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) @@ -716,14 +761,15 @@ class ConnectionState: guild_id = utils._get_as_snowflake(data, 'guild_id') guild = self._get_guild(guild_id) if guild is not None: - channel = factory(guild=guild, state=self, data=data) - guild._add_channel(channel) + # the factory can't be a DMChannel or GroupChannel here + channel = factory(guild=guild, state=self, data=data) # type: ignore + guild._add_channel(channel) # type: ignore self.dispatch('guild_channel_create', channel) else: log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) return - def parse_channel_pins_update(self, data): + def parse_channel_pins_update(self, data) -> None: channel_id = int(data['channel_id']) try: guild = self._get_guild(int(data['guild_id'])) @@ -744,7 +790,7 @@ class ConnectionState: else: self.dispatch('guild_channel_pins_update', channel, last_pin) - def parse_thread_create(self, data): + def parse_thread_create(self, data) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -757,7 +803,7 @@ class ConnectionState: if not has_thread: self.dispatch('thread_join', thread) - def parse_thread_update(self, data): + def parse_thread_update(self, data) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) if guild is None: @@ -775,7 +821,7 @@ class ConnectionState: guild._add_thread(thread) self.dispatch('thread_join', thread) - def parse_thread_delete(self, data): + def parse_thread_delete(self, data) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) if guild is None: @@ -785,10 +831,10 @@ class ConnectionState: thread_id = int(data['id']) thread = guild.get_thread(thread_id) if thread is not None: - guild._remove_thread(thread) + guild._remove_thread(thread) # type: ignore self.dispatch('thread_delete', thread) - def parse_thread_list_sync(self, data): + def parse_thread_list_sync(self, data) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -827,7 +873,7 @@ class ConnectionState: for thread in previous_threads.values(): self.dispatch('thread_remove', thread) - def parse_thread_member_update(self, data): + def parse_thread_member_update(self, data) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -843,7 +889,7 @@ class ConnectionState: member = ThreadMember(thread, data) thread.me = member - def parse_thread_members_update(self, data): + def parse_thread_members_update(self, data) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: @@ -875,7 +921,7 @@ class ConnectionState: else: self.dispatch('thread_remove', thread) - def parse_guild_member_add(self, data): + def parse_guild_member_add(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -892,7 +938,7 @@ class ConnectionState: self.dispatch('member_join', member) - def parse_guild_member_remove(self, data): + def parse_guild_member_remove(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: try: @@ -903,12 +949,12 @@ class ConnectionState: user_id = int(data['user']['id']) member = guild.get_member(user_id) if member is not None: - guild._remove_member(member) + guild._remove_member(member) # type: ignore self.dispatch('member_remove', member) else: log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_guild_member_update(self, data): + def parse_guild_member_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) user = data['user'] user_id = int(user['id']) @@ -937,7 +983,7 @@ class ConnectionState: guild._add_member(member) log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) - def parse_guild_emojis_update(self, data): + def parse_guild_emojis_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -946,10 +992,11 @@ class ConnectionState: before_emojis = guild.emojis for emoji in before_emojis: self._emojis.pop(emoji.id, None) - guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) + # guild won't be None here + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) #type: ignore self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) - def parse_guild_stickers_update(self, data): + def parse_guild_stickers_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: log.debug('GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -958,7 +1005,8 @@ class ConnectionState: before_stickers = guild.stickers for emoji in before_stickers: self._stickers.pop(emoji.id, None) - guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) + # guild won't be None here + guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) def _get_create_guild(self, data): @@ -999,7 +1047,7 @@ class ConnectionState: else: self.dispatch('guild_join', guild) - def parse_guild_create(self, data): + def parse_guild_create(self, data) -> None: unavailable = data.get('unavailable') if unavailable is True: # joined a guild with unavailable == True so.. @@ -1027,7 +1075,7 @@ class ConnectionState: else: self.dispatch('guild_join', guild) - def parse_guild_update(self, data): + def parse_guild_update(self, data) -> None: guild = self._get_guild(int(data['id'])) if guild is not None: old_guild = copy.copy(guild) @@ -1036,7 +1084,7 @@ class ConnectionState: else: log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) - def parse_guild_delete(self, data): + def parse_guild_delete(self, data) -> None: guild = self._get_guild(int(data['id'])) if guild is None: log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) @@ -1051,12 +1099,12 @@ class ConnectionState: # do a cleanup of the messages cache if self._messages is not None: - self._messages = deque((msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages) + self._messages: Optional[Deque[Message]] = deque((msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages) self._remove_guild(guild) self.dispatch('guild_remove', guild) - def parse_guild_ban_add(self, data): + def parse_guild_ban_add(self, data) -> None: # we make the assumption that GUILD_BAN_ADD is done # before GUILD_MEMBER_REMOVE is called # hence we don't remove it from cache or do anything @@ -1072,13 +1120,13 @@ class ConnectionState: member = guild.get_member(user.id) or user self.dispatch('member_ban', guild, member) - def parse_guild_ban_remove(self, data): + def parse_guild_ban_remove(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None and 'user' in data: user = self.store_user(data['user']) self.dispatch('member_unban', guild, user) - def parse_guild_role_create(self, data): + def parse_guild_role_create(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -1089,7 +1137,7 @@ class ConnectionState: guild._add_role(role) self.dispatch('guild_role_create', role) - def parse_guild_role_delete(self, data): + def parse_guild_role_delete(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: role_id = int(data['role_id']) @@ -1102,7 +1150,7 @@ class ConnectionState: else: log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_guild_role_update(self, data): + def parse_guild_role_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: role_data = data['role'] @@ -1115,12 +1163,13 @@ class ConnectionState: else: log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_guild_members_chunk(self, data): + def parse_guild_members_chunk(self, data) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) presences = data.get('presences', []) - members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] + # the guild won't be None here + members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) if presences: @@ -1129,19 +1178,20 @@ class ConnectionState: user = presence['user'] member_id = user['id'] member = member_dict.get(member_id) - member._presence_update(presence, user) + if member is not None: + member._presence_update(presence, user) complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count') self.process_chunk_requests(guild_id, data.get('nonce'), members, complete) - def parse_guild_integrations_update(self, data): + def parse_guild_integrations_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: self.dispatch('guild_integrations_update', guild) else: log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_integration_create(self, data): + def parse_integration_create(self, data) -> None: guild_id = int(data.pop('guild_id')) guild = self._get_guild(guild_id) if guild is not None: @@ -1151,7 +1201,7 @@ class ConnectionState: else: log.debug('INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_integration_update(self, data): + def parse_integration_update(self, data) -> None: guild_id = int(data.pop('guild_id')) guild = self._get_guild(guild_id) if guild is not None: @@ -1161,7 +1211,7 @@ class ConnectionState: else: log.debug('INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_integration_delete(self, data): + def parse_integration_delete(self, data) -> None: guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) if guild is not None: @@ -1170,7 +1220,7 @@ class ConnectionState: else: log.debug('INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.', guild_id) - def parse_webhooks_update(self, data): + def parse_webhooks_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: log.debug('WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding', data['guild_id']) @@ -1182,7 +1232,7 @@ class ConnectionState: else: log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) - def parse_stage_instance_create(self, data): + def parse_stage_instance_create(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: stage_instance = StageInstance(guild=guild, state=self, data=data) @@ -1191,7 +1241,7 @@ class ConnectionState: else: log.debug('STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_stage_instance_update(self, data): + def parse_stage_instance_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: stage_instance = guild._stage_instances.get(int(data['id'])) @@ -1204,7 +1254,7 @@ class ConnectionState: else: log.debug('STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_stage_instance_delete(self, data): + def parse_stage_instance_delete(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is not None: try: @@ -1216,11 +1266,12 @@ class ConnectionState: else: log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) - def parse_voice_state_update(self, data): + def parse_voice_state_update(self, data) -> None: guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) channel_id = utils._get_as_snowflake(data, 'channel_id') flags = self.member_cache_flags - self_id = self.user.id + # self.user is *always* cached when this is called + self_id = self.user.id # type: ignore if guild is not None: if int(data['user_id']) == self_id: voice = self._get_voice_client(guild.id) @@ -1228,12 +1279,13 @@ class ConnectionState: coro = voice.on_voice_state_update(data) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) - member, before, after = guild._update_voice_state(data, channel_id) + member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: if flags.voice: if channel_id is None and flags._voice_only and member.id != self_id: - # Only remove from cache iff we only have the voice flag enabled - guild._remove_member(member) + # Only remove from cache if we only have the voice flag enabled + # Member doesn't meet the Snowflake protocol currently + guild._remove_member(member) # type: ignore elif channel_id is not None: guild._add_member(member) @@ -1241,7 +1293,7 @@ class ConnectionState: else: log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) - def parse_voice_server_update(self, data): + def parse_voice_server_update(self, data) -> None: try: key_id = int(data['guild_id']) except KeyError: @@ -1252,15 +1304,18 @@ class ConnectionState: coro = vc.on_voice_server_update(data) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) - def parse_typing_start(self, data): + def parse_typing_start(self, data) -> None: channel, guild = self._get_guild_channel(data) if channel is not None: member = None user_id = utils._get_as_snowflake(data, 'user_id') if isinstance(channel, DMChannel): member = channel.recipient + elif isinstance(channel, (Thread, TextChannel)) and guild is not None: - member = guild.get_member(user_id) + # user_id won't be None + member = guild.get_member(user_id) # type: ignore + if member is None: member_data = data.get('member') if member_data: @@ -1273,12 +1328,12 @@ class ConnectionState: timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) self.dispatch('typing', channel, member, timestamp) - def _get_reaction_user(self, channel, user_id): + def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, TextChannel): return channel.guild.get_member(user_id) return self.get_user(user_id) - def get_reaction_emoji(self, data): + def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: emoji_id = utils._get_as_snowflake(data, 'id') if not emoji_id: @@ -1289,7 +1344,7 @@ class ConnectionState: except KeyError: return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) - def _upgrade_partial_emoji(self, emoji): + def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: emoji_id = emoji.id if not emoji_id: return emoji.name @@ -1298,7 +1353,7 @@ class ConnectionState: except KeyError: return emoji - def get_channel(self, id): + def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]: if id is None: return None @@ -1311,18 +1366,18 @@ class ConnectionState: if channel is not None: return channel - def create_message(self, *, channel, data): + def create_message(self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload) -> Message: return Message(state=self, channel=channel, data=data) class AutoShardedConnectionState(ConnectionState): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._ready_task = None - self.shard_ids = () - self.shards_launched = asyncio.Event() + self.shard_ids: Union[List[int], range] = [] + self.shards_launched: asyncio.Event = asyncio.Event() - def _update_message_references(self): - for msg in self._messages: + def _update_message_references(self) -> None: + # self._messages won't be None when this is called + for msg in self._messages: # type: ignore if not msg.guild: continue @@ -1330,13 +1385,14 @@ class AutoShardedConnectionState(ConnectionState): if new_guild is not None and new_guild is not msg.guild: channel_id = msg.channel.id channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id) - msg._rebind_cached_references(new_guild, channel) + # channel will either be a TextChannel, Thread or Object + msg._rebind_cached_references(new_guild, channel) # type: ignore - async def chunker(self, guild_id, query='', limit=0, presences=False, *, shard_id=None, nonce=None): + async def chunker(self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, shard_id: Optional[int] = None, nonce: Optional[str] = None) -> None: ws = self._get_websocket(guild_id, shard_id=shard_id) await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) - async def _delay_ready(self): + async def _delay_ready(self) -> None: await self.shards_launched.wait() processed = [] max_concurrency = len(self.shard_ids) * 2 @@ -1403,12 +1459,13 @@ class AutoShardedConnectionState(ConnectionState): self.call_handlers('ready') self.dispatch('ready') - def parse_ready(self, data): + def parse_ready(self, data) -> None: if not hasattr(self, '_ready_state'): self._ready_state = asyncio.Queue() self.user = user = ClientUser(state=self, data=data['user']) - self._users[user.id] = user + # self._users is a list of Users, we're setting a ClientUser + self._users[user.id] = user # type: ignore if self.application_id is None: try: @@ -1431,6 +1488,6 @@ class AutoShardedConnectionState(ConnectionState): if self._ready_task is None: self._ready_task = asyncio.create_task(self._delay_ready()) - def parse_resumed(self, data): + def parse_resumed(self, data) -> None: self.dispatch('resumed') self.dispatch('shard_resumed', data['__shard_id__']) diff --git a/discord/types/appinfo.py b/discord/types/appinfo.py index d223837fa..912d5ad5d 100644 --- a/discord/types/appinfo.py +++ b/discord/types/appinfo.py @@ -61,6 +61,7 @@ class _PartialAppInfoOptional(TypedDict, total=False): terms_of_service_url: str privacy_policy_url: str max_participants: int + flags: int class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo): pass diff --git a/discord/types/voice.py b/discord/types/voice.py index b29288d45..825840258 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -24,14 +24,14 @@ DEALINGS IN THE SOFTWARE. from typing import Optional, TypedDict, List, Literal from .snowflake import Snowflake -from .member import Member +from .member import GatewayMember SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305'] class _PartialVoiceStateOptional(TypedDict, total=False): - member: Member + member: GatewayMember self_stream: bool