From c6e6c22a954acf8e171ddc3b57e7525183adfe5e Mon Sep 17 00:00:00 2001 From: dolfies Date: Tue, 9 Nov 2021 21:33:27 -0500 Subject: [PATCH] Migrate state.py --- discord/state.py | 619 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 436 insertions(+), 183 deletions(-) diff --git a/discord/state.py b/discord/state.py index 89198213f..8839aa30a 100644 --- a/discord/state.py +++ b/discord/state.py @@ -25,15 +25,15 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import asyncio -from collections import deque, OrderedDict +from collections import deque import copy import datetime -import itertools import logging from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque import inspect - +import time import os +import random from .guild import Guild from .activity import BaseActivity @@ -46,18 +46,19 @@ from .channel import * from .channel import _channel_factory from .raw_models import * from .member import Member +from .relationship import Relationship from .role import Role -from .enums import ChannelType, try_enum, Status +from .enums import ChannelType, RequiredActionType, Status, try_enum, UnavailableGuildType, VoiceRegion from . import utils -from .flags import MemberCacheFlags +from .flags import GuildSubscriptionOptions, MemberCacheFlags from .object import Object from .invite import Invite from .integrations import _integration_factory -from .interactions import Interaction -from .ui.view import ViewStore, View from .stage_instance import StageInstance from .threads import Thread, ThreadMember from .sticker import GuildSticker +from .settings import UserSettings + if TYPE_CHECKING: from .abc import PrivateChannel @@ -67,6 +68,8 @@ if TYPE_CHECKING: from .voice_client import VoiceProtocol from .client import Client from .gateway import DiscordWebSocket + from .calls import Call + from .member import VoiceState from .types.activity import Activity as ActivityPayload from .types.channel import DMChannel as DMChannelPayload @@ -75,6 +78,7 @@ if TYPE_CHECKING: from .types.sticker import GuildSticker as GuildStickerPayload from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload + from .types.voice import GuildVoiceState T = TypeVar('T') CS = TypeVar('CS', bound='ConnectionState') @@ -136,7 +140,7 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> try: await coroutine except Exception: - _log.exception('Exception occurred during %s', info) + _log.exception('Exception occurred during %s.', info) class ConnectionState: @@ -153,10 +157,12 @@ class ConnectionState: hooks: Dict[str, Callable], http: HTTPClient, loop: asyncio.AbstractEventLoop, + client: Client, **options: Any, ) -> None: self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http + self.client = client self.max_messages: Optional[int] = options.get('max_messages', 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 @@ -166,9 +172,6 @@ class ConnectionState: self.hooks: Dict[str, Callable] = hooks self._ready_task: Optional[asyncio.Task] = None self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0) - self.guild_ready_timeout: float = options.get('guild_ready_timeout', 2.0) - if self.guild_ready_timeout < 0: - raise ValueError('guild_ready_timeout cannot be negative') allowed_mentions = options.get('allowed_mentions') @@ -193,6 +196,16 @@ class ConnectionState: status = str(status) self._chunk_guilds: bool = options.get('chunk_guilds_at_startup', True) + self._request_guilds = options.get('request_guilds', True) + + subscription_options = options.get('guild_subscription_options') + if subscription_options is None: + subscription_options = GuildSubscriptionOptions.off() + else: + if not isinstance(subscription_options, GuildSubscriptionOptions): + raise TypeError(f'subscription_options parameter must be GuildSubscriptionOptions not {type(subscription_options)!r}') + self._subscription_options = subscription_options + self._subscribe_guilds = subscription_options.auto_subscribe cache_flags = options.get('member_cache_flags', None) if cache_flags is None: @@ -216,33 +229,38 @@ class ConnectionState: self.clear() - def clear(self, *, views: bool = True) -> None: + def clear(self) -> None: self.user: Optional[ClientUser] = None + self.settings: Optional[UserSettings] = None + self.analytics_token: Optional[str] = None # Originally, this code used WeakValueDictionary to maintain references to the - # global user mapping. + # global user mapping # However, profiling showed that this came with two cons: # 1. The __weakref__ slot caused a non-trivial increase in memory - # 2. The performance of the mapping caused store_user to be a bottleneck. + # 2. The performance of the mapping caused store_user to be a bottleneck # Since this is undesirable, a mapping is now used instead with stored # references now using a regular dictionary with eviction being done - # using __del__. Testing this for memory leaks led to no discernable leaks, - # though more testing will have to be done. + # using __del__ + # Testing this for memory leaks led to no discernable leaks self._users: Dict[int, User] = {} self._emojis: Dict[int, Emoji] = {} self._stickers: Dict[int, GuildSticker] = {} self._guilds: Dict[int, Guild] = {} - if views: - self._view_store: ViewStore = ViewStore(self) + self._queued_guilds: Dict[int, Guild] = {} + self._unavailable_guilds: Dict[int, UnavailableGuildType] = {} + self._calls: Dict[int, Call] = {} + self._call_message_cache: List[Message] = [] # Hopefully this won't be a memory leak self._voice_clients: Dict[int, VoiceProtocol] = {} + self._voice_states: Dict[int, VoiceState] = {} - # LRU of max size 128 - self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict() - # extra dict to look up private channels by user id + self._private_channels: Dict[int, PrivateChannel] = {} self._private_channels_by_user: Dict[int, DMChannel] = {} + self._last_private_channel: tuple = (None, None) + if self.max_messages is not None: self._messages: Optional[Deque[Message]] = deque(maxlen=self.max_messages) else: @@ -276,6 +294,10 @@ class ConnectionState: else: await coro(*args, **kwargs) + @property + def ws(self): + return self.client.ws + @property def self_id(self) -> Optional[int]: u = self.user @@ -285,8 +307,33 @@ class ConnectionState: def voice_clients(self) -> List[VoiceProtocol]: return list(self._voice_clients.values()) + def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[User, VoiceState, VoiceState]: + user_id = int(data['user_id']) + user = self.get_user(user_id) + channel = self._get_private_channel(channel_id) + + try: + # Check if we should remove the voice state from cache + if channel is None: + after = self._voice_states.pop(user_id) + else: + after = self._voice_states[user_id] + + before = copy.copy(after) + after._update(data, channel) + except KeyError: + # if we're here then add it into the cache + after = VoiceState(data=data, channel=channel) + before = VoiceState(data=data, channel=None) + self._voice_states[user_id] = after + + return user, before, after + + def _voice_state_for(self, user_id: int) -> Optional[VoiceState]: + return self._voice_states.get(user_id) + def _get_voice_client(self, guild_id: Optional[int]) -> Optional[VoiceProtocol]: - # the keys of self._voice_clients are ints + # The keys of self._voice_clients are ints return self._voice_clients.get(guild_id) # type: ignore def _add_voice_client(self, guild_id: int, voice: VoiceProtocol) -> None: @@ -302,7 +349,15 @@ class ConnectionState: def store_user(self, data: UserPayload) -> User: user_id = int(data['id']) try: - return self._users[user_id] + user = self._users[user_id] + # We use the data available to us since we + # might not have events for that user + # However, the data may only have an ID + try: + user._update(data) + except KeyError: + pass + return user except KeyError: user = User(state=self, data=data) if user.discriminator != '0000': @@ -317,14 +372,14 @@ class ConnectionState: return User(state=self, data=data) def deref_user_no_intents(self, user_id: int) -> None: - return + pass def get_user(self, id: Optional[int]) -> Optional[User]: - # the keys of self._users are ints + # The keys of self._users are ints return self._users.get(id) # type: ignore def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: - # the id will be present here + # The id will be present here emoji_id = int(data['id']) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji @@ -334,23 +389,16 @@ class ConnectionState: self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker - def store_view(self, view: View, message_id: Optional[int] = None) -> None: - self._view_store.add_view(view, message_id) - - def prevent_view_updates_for(self, message_id: int) -> Optional[View]: - return self._view_store.remove_message_tracking(message_id) - - @property - def persistent_views(self) -> Sequence[View]: - return self._view_store.persistent_views - @property def guilds(self) -> List[Guild]: return list(self._guilds.values()) def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]: - # the keys of self._guilds are ints - return self._guilds.get(guild_id) # type: ignore + # The keys of self._guilds are ints + guild = self._guilds.get(guild_id) # type: ignore + if guild is None: + guild = self._queued_guilds.get(guild_id) # type: ignore + return guild def _add_guild(self, guild: Guild) -> None: self._guilds[guild.id] = guild @@ -386,29 +434,39 @@ class ConnectionState: def private_channels(self) -> List[PrivateChannel]: return list(self._private_channels.values()) - def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: + async def access_private_channel(self, channel_id: int) -> None: + if not self._get_accessed_private_channel(channel_id): + await self._access_private_channel(channel_id) + self._set_accessed_private_channel(channel_id) + + async def _access_private_channel(self, channel_id: int) -> None: + if (ws := self.ws) is None: + return + try: - # the keys of self._private_channels are ints - value = self._private_channels[channel_id] # type: ignore - except KeyError: - return None - else: - self._private_channels.move_to_end(channel_id) # type: ignore - return value + await ws.access_dm(channel_id) + except Exception as exc: + _log.warning('Sending ACCESS_DM failed for channel %s, (%s).', channel_id, exc) + + def _set_accessed_private_channel(self, channel_id): + self._last_private_channel = (channel_id, time.time()) + + def _get_accessed_private_channel(self, channel_id): + timestamp, existing_id = self._last_private_channel + return existing_id == channel_id and int(time.time() - timestamp) < random.randrange(120000, 420000) + + def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: + # The keys of self._private_channels are ints + return self._private_channels.get(channel_id) # type: ignore def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[DMChannel]: - # the keys of self._private_channels are ints + # The keys of self._private_channels are ints return self._private_channels_by_user.get(user_id) # type: ignore def _add_private_channel(self, channel: PrivateChannel) -> None: channel_id = channel.id self._private_channels[channel_id] = channel - if len(self._private_channels) > 128: - _, to_remove = self._private_channels.popitem(last=False) - if isinstance(to_remove, DMChannel) and to_remove.recipient: - self._private_channels_by_user.pop(to_remove.recipient.id, None) - if isinstance(channel, DMChannel) and channel.recipient: self._private_channels_by_user[channel.recipient.id] = channel @@ -428,71 +486,77 @@ class ConnectionState: def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None - def _add_guild_from_data(self, data: GuildPayload) -> Guild: - guild = Guild(data=data, state=self) - self._add_guild(guild) - return guild + def _add_guild_from_data(self, guild: GuildPayload, *, from_ready: bool = False) -> Guild: + guild_id = int(guild['id']) + unavailable = guild.get('unavailable', False) + + if not unavailable: + guild = Guild(data=guild, state=self) + self._add_guild(guild) + return guild + else: + self._unavailable_guilds[guild_id] = UnavailableGuildType.existing if from_ready else UnavailableGuildType.joined + _log.debug('Forcing GUILD_CREATE for unavailable guild %s.' % guild_id) + asyncio.ensure_future(self.request_guild(guild_id), loop=self.loop) def _guild_needs_chunking(self, guild: Guild) -> bool: - # If presences are enabled then we get back the old guild.large behaviour - return self._chunk_guilds and not guild.chunked and not (True and not guild.large) + return self._chunk_guilds and not guild.chunked and any( + guild.me.guild_permissions.kick_members, + guild.me.guild_permissions.manage_roles, + guild.me.guild_permissions.ban_members + ) + + def _guild_needs_subscribing(self, guild): # TODO: rework + return not guild.subscribed and self._subscribe_guilds def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]: channel_id = int(data['channel_id']) try: guild = self._get_guild(int(data['guild_id'])) except KeyError: - channel = DMChannel._from_message(self, channel_id) + channel = self.get_channel(channel_id) guild = None else: channel = guild and guild._resolve_channel(channel_id) return channel or PartialMessageable(state=self, id=channel_id), guild - async def chunker( - self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None + def request_guild(self, guild_id: int) -> None: + return self.ws.request_lazy_guild(guild_id, typing=True, activities=True, threads=True) + + def chunker( + self, guild_id: int, query: str = '', limit: int = 0, presences: bool = True, *, nonce: Optional[str] = None ) -> None: - ws = self._get_websocket(guild_id) # This is ignored upstream - await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + return self.ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): guild_id = guild.id - ws = self._get_websocket(guild_id) - if ws is None: - raise RuntimeError('Somehow do not have a websocket for this guild_id') - request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request try: - # start the query operation - await ws.request_chunks( - guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce + await self.ws.request_chunks( + [guild_id], query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce ) return await asyncio.wait_for(request.wait(), timeout=30.0) except asyncio.TimeoutError: - _log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) + _log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d.', query, limit, guild_id) raise async def _delay_ready(self) -> None: try: states = [] - while True: - # this snippet of code is basically waiting N seconds - # until the last GUILD_CREATE was sent - try: - guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) - except asyncio.TimeoutError: - break - else: - if self._guild_needs_chunking(guild): - future = await self.chunk_guild(guild, wait=False) - states.append((guild, future)) - else: - if guild.unavailable is False: - self.dispatch('guild_available', guild) - else: - self.dispatch('guild_join', guild) + subscribes = [] + for guild in self._guilds.values(): + if self._request_guilds: + await self.request_guild(guild.id) + + if self._guild_needs_chunking(guild): + future = await self.chunk_guild(guild, wait=False) + states.append((guild, future)) + + if self._guild_needs_subscribing(guild): + subscribes.append(guild) for guild, future in states: try: @@ -500,48 +564,92 @@ class ConnectionState: except asyncio.TimeoutError: _log.warning('Timed out waiting for chunks for guild_id %s.', guild.id) - if guild.unavailable is False: - self.dispatch('guild_available', guild) - else: - self.dispatch('guild_join', guild) - - # remove the state - try: - del self._ready_state - except AttributeError: - pass # already been deleted somehow + options = self._subscription_options + ticket = asyncio.Semaphore(options.concurrent_guilds) + await asyncio.gather(*[guild.subscribe(ticket=ticket, max_online=options.max_online) for guild in subscribes]) except asyncio.CancelledError: pass else: - # dispatch the event + # Dispatch the event self.call_handlers('ready') self.dispatch('ready') finally: self._ready_task = None def parse_ready(self, data) -> None: + # Before parsing, we wait for READY_SUPPLEMENTAL + # This has voice state objects, as well as an initial member cache + self._ready_data: dict = data + + def parse_ready_supplemental(self, data) -> None: if self._ready_task is not None: self._ready_task.cancel() - self._ready_state = asyncio.Queue() - self.clear(views=False) - self.user = ClientUser(state=self, data=data['user']) - self.store_user(data['user']) + self.clear() - if self.application_id is None: + # Merge with READY data + extra_data = data + data = self._ready_data + + # Discord bad + for guild_data, guild_extra, merged_members, merged_me, merged_presences in zip( + data.get('guilds', []), + extra_data.get('guilds', []), + extra_data.get('merged_members', []), + data.get('merged_members', []), + extra_data['merged_presences'].get('guilds', []) + ): + guild_data['voice_states'] = guild_extra.get('voice_states', []) + guild_data['merged_members'] = merged_me + guild_data['merged_members'].extend(merged_members) + guild_data['merged_presences'] = merged_presences + # There's also a friends key that has presence data for your friends + # Parsing that would require a redesign of the Relationship class ;-; + + # Self parsing + self.user = ClientUser(state=self, data=data['user']) + user = self.store_user(data['user']) + + # Temp user parsing + temp_users = {user.id: user._to_minimal_user_json()} + for u in data.get('users', []): + u_id = int(u['id']) + temp_users[u_id] = u + + # Guild parsing + for guild_data in data.get('guilds', []): + for member in guild_data['merged_members']: + if 'user' not in member: + member['user'] = temp_users.get(int(member.pop('user_id'))) + self._add_guild_from_data(guild_data, from_ready=True) + + # Relationship parsing + for relationship in data.get('relationships', []): try: - application = data['application'] + r_id = int(relationship['id']) except KeyError: - pass + continue else: - self.application_id = utils._get_as_snowflake(application, 'id') - # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore - - for guild_data in data['guilds']: - self._add_guild_from_data(guild_data) - + if 'user' not in relationship: + relationship['user'] = temp_users[int(relationship.pop('user_id'))] + user._relationships[r_id] = Relationship(state=self, data=relationship) + + # Private channel parsing + for pm in data.get('private_channels', []): + factory, _ = _channel_factory(pm['type']) + if 'recipients' not in pm: + pm['recipients'] = [temp_users[int(u_id)] for u_id in pm.pop('recipient_ids')] + self._add_private_channel(factory(me=user, data=pm, state=self)) + + # Extras + region = data.get('geo_ordered_rtc_regions', ['us-west'])[0] + self.preferred_region = try_enum(VoiceRegion, region) + self.settings = UserSettings(data=data.get('user_settings', {}), state=self) + + # We're done + del self._ready_data + self.call_handlers('connect') self.dispatch('connect') self._ready_task = asyncio.create_task(self._delay_ready()) @@ -549,13 +657,20 @@ class ConnectionState: self.dispatch('resumed') def parse_message_create(self, data) -> None: + guild_id = utils._get_as_snowflake(data, 'guild_id') channel, _ = self._get_guild_channel(data) - # channel would be the correct type here + if guild_id in self._unavailable_guilds: # I don't know how I feel about this :( + return + + # Channel will be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore self.dispatch('message', message) if self._messages is not None: self._messages.append(message) - # we ensure that the channel is either a TextChannel or Thread + if message.call is not None: + self._call_message_cache[message.id] = message + + # We ensure that the channel is either a TextChannel or Thread if channel and channel.__class__ in (TextChannel, Thread): channel.last_message_id = message.id # type: ignore @@ -597,9 +712,6 @@ class ConnectionState: else: self.dispatch('raw_message_edit', raw) - if 'components' in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data['components']) - def parse_message_reaction_add(self, data) -> None: emoji = data['emoji'] emoji_id = utils._get_as_snowflake(emoji, 'id') @@ -673,15 +785,6 @@ class ConnectionState: if reaction: self.dispatch('reaction_clear_emoji', reaction) - def parse_interaction_create(self, data) -> None: - interaction = Interaction(data=data, state=self) - if data['type'] == 3: # interaction component - custom_id = interaction.data['custom_id'] # type: ignore - component_type = interaction.data['component_type'] # type: ignore - self._view_store.dispatch(component_type, custom_id, interaction) - - self.dispatch('interaction', interaction) - def parse_presence_update(self, data) -> None: guild_id = utils._get_as_snowflake(data, 'guild_id') # guild_id won't be None here @@ -791,6 +894,22 @@ class ConnectionState: else: self.dispatch('guild_channel_pins_update', channel, last_pin) + def parse_channel_recipient_add(self, data) -> None: + channel = self._get_private_channel(int(data['channel_id'])) + user = self.store_user(data['user']) + channel.recipients.append(user) + self.dispatch('group_join', channel, user) + + def parse_channel_recipient_remove(self, data) -> None: + channel = self._get_private_channel(int(data['channel_id'])) + user = self.store_user(data['user']) + try: + channel.recipients.remove(user) + except ValueError: + pass + else: + self.dispatch('group_remove', channel, user) + def parse_thread_create(self, data) -> None: guild_id = int(data['guild_id']) guild: Optional[Guild] = self._get_guild(guild_id) @@ -934,7 +1053,7 @@ class ConnectionState: except AttributeError: pass - self.dispatch('member_join', member) + # self.dispatch('member_join', member) def parse_guild_member_remove(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) @@ -948,7 +1067,7 @@ class ConnectionState: member = guild.get_member(user_id) if member is not None: guild._remove_member(member) # type: ignore - self.dispatch('member_remove', member) + # self.dispatch('member_remove', member) else: _log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) @@ -981,6 +1100,100 @@ class ConnectionState: guild._add_member(member) _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + def parse_guild_sync(self, data) -> None: + print('I noticed you triggered a `GUILD_SYNC`.\nIf you want to share your secrets, please feel free to email me.') + + def parse_guild_member_list_update(self, data) -> None: # Rewrite incoming... + self.dispatch('raw_guild_member_list_update', data) + guild = self._get_guild(int(data['guild_id'])) + if guild is None: + _log.debug('GUILD_MEMBER_LIST_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + return + + ops = data['ops'] + + if data['member_count'] > 0: + guild._member_count = data['member_count'] + + online_count = 0 + for group in data['groups']: + online_count += group['count'] if group['id'] != 'offline' else 0 + guild._online_count = online_count + + for opdata in ops: + op = opdata['op'] + # There are two OPs I'm not parsing. + # INVALIDATE: Usually invalid (hehe). + # DELETE: Sends the index, not the user ID, so I can't do anything with + # it unless I keep a seperate list of the member sidebar (maybe in future). + + if op == 'SYNC': + members = [Member(guild=guild, data=member['member'], state=self) for member in [item for item in opdata.get('items', []) if 'member' in item]] + + member_dict = {str(member.id): member for member in members} + for presence in [item for item in opdata.get('items', []) if 'member' in item]: + presence = presence['member']['presence'] + user = presence['user'] + member_id = user['id'] + member = member_dict.get(member_id) + member._presence_update(presence, user) + + for member in members: + guild._add_member(member) + + if op == 'INSERT': + if 'member' not in opdata['item']: + # Hoisted role INSERT + return + + mdata = opdata['item']['member'] + user = mdata['user'] + user_id = int(user['id']) + + member = guild.get_member(user_id) + if member is not None: # INSERTs are also sent when a user changes range + old_member = Member._copy(member) + member._update(mdata) + user_update = member._update_inner_user(user) + if 'presence' in mdata: + presence = mdata['presence'] + user = presence['user'] + member_id = user['id'] + member._presence_update(presence, user) + if user_update: + self.dispatch('user_update', user_update[0], user_update[1]) + + self.dispatch('member_update', old_member, member) + else: + member = Member(data=mdata, guild=guild, state=self) + guild._add_member(member) + + if op == 'UPDATE': + if 'member' not in opdata['item']: + # Hoisted role UPDATE + return + + mdata = opdata['item']['member'] + user = mdata['user'] + user_id = int(user['id']) + + member = guild.get_member(user_id) + if member is not None: + old_member = Member._copy(member) + member._update(mdata) + user_update = member._update_inner_user(user) + if 'presence' in mdata: + presence = mdata['presence'] + user = presence['user'] + member_id = user['id'] + member._presence_update(presence, user) + if user_update: + self.dispatch('user_update', user_update[0], user_update[1]) + + self.dispatch('member_update', old_member, member) + else: + _log.debug('GUILD_MEMBER_LIST_UPDATE type UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + def parse_guild_emojis_update(self, data) -> None: guild = self._get_guild(int(data['guild_id'])) if guild is None: @@ -1008,15 +1221,12 @@ class ConnectionState: self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) def _get_create_guild(self, data): - if data.get('unavailable') is False: - # GUILD_CREATE with unavailable in the response - # usually means that the guild has become available - # and is therefore in the cache - guild = self._get_guild(int(data['id'])) - if guild is not None: - guild.unavailable = False - guild._from_data(data) - return guild + guild = self._get_guild(int(data['id'])) + # Discord being Discord sends a GUILD_CREATE after an OPCode 14 is sent (a la bots) + # However, we want that if we forced a GUILD_CREATE for an unavailable guild + if guild is not None: + guild._from_data(data) + return return self._add_guild_from_data(data) @@ -1034,44 +1244,44 @@ class ConnectionState: return await request.wait() return request.get_future() - async def _chunk_and_dispatch(self, guild, unavailable): - try: - await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) - except asyncio.TimeoutError: - _log.info('Somehow timed out waiting for chunks.') + async def _parse_and_dispatch(self, guild, *, chunk, subscribe) -> None: + self._queued_guilds[guild.id] = guild + + if chunk: + try: + await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) + except asyncio.TimeoutError: + log.info('Somehow timed out waiting for chunks.') - if unavailable is False: - self.dispatch('guild_available', guild) + if subscribe: + await guild.subscribe(max_online=self._subscription_options.max_online) + + self._queued_guilds.pop(guild.id) + + # Dispatch available/join depending on circumstances + if guild.id in self._unavailable_guilds: + type = self._unavailable_guilds.pop(guild.id) + if type is UnavailableGuildType.existing: + self.dispatch('guild_available', guild) + else: + self.dispatch('guild_join', guild) else: self.dispatch('guild_join', guild) - def parse_guild_create(self, data) -> None: - unavailable = data.get('unavailable') - if unavailable is True: - # joined a guild with unavailable == True so.. - return + def parse_guild_create(self, data): + guild_id = int(data['id']) guild = self._get_create_guild(data) - try: - # Notify the on_ready state, if any, that this guild is complete. - self._ready_state.put_nowait(guild) - except AttributeError: - pass - else: - # If we're waiting for the event, put the rest on hold + if guild is None: return - # check if it requires chunking - if self._guild_needs_chunking(guild): - asyncio.create_task(self._chunk_and_dispatch(guild, unavailable)) - return + if self._request_guilds: + asyncio.ensure_future(self.request_guild(guild.id), loop=self.loop) - # Dispatch available if newly available - if unavailable is False: - self.dispatch('guild_available', guild) - else: - self.dispatch('guild_join', guild) + # Chunk/subscribe if needed + needs_chunking, needs_subscribing = self._guild_needs_chunking(guild), self._guild_needs_subscribing(guild) + asyncio.ensure_future(self._parse_and_dispatch(guild, chunk=needs_chunking, subscribe=needs_subscribing), loop=self.loop) def parse_guild_update(self, data) -> None: guild = self._get_guild(int(data['id'])) @@ -1095,7 +1305,7 @@ class ConnectionState: self.dispatch('guild_unavailable', guild) return - # do a cleanup of the messages cache + # Cleanup the message cache if self._messages is not None: self._messages: Optional[Deque[Message]] = deque( (msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages @@ -1105,11 +1315,6 @@ class ConnectionState: self.dispatch('guild_remove', guild) def parse_guild_ban_add(self, data) -> None: - # we make the assumption that GUILD_BAN_ADD is done - # before GUILD_MEMBER_REMOVE is called - # hence we don't remove it from cache or do anything - # strange with it, the main purpose of this event - # is mainly to dispatch to another event worth listening to for logging guild = self._get_guild(int(data['guild_id'])) if guild is not None: try: @@ -1168,7 +1373,7 @@ class ConnectionState: guild = self._get_guild(guild_id) presences = data.get('presences', []) - # the guild won't be None here + # The guild won't be None here members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore _log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) @@ -1266,24 +1471,43 @@ class ConnectionState: else: _log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + def parse_call_create(self, data) -> None: + channel = self._get_private_channel(int(data['channel_id'])) + message = self._call_message_cache.pop((int(data['message_id'])), None) + call = channel._add_call(state=self, message=message, channel=channel, **data) + self._calls[channel.id] = call + self.dispatch('call_create', call) + + def parse_call_update(self, data) -> None: + call = self._calls.get(int(data['channel_id'])) + call._update(**data) + self.dispatch('call_update', call) + + def parse_call_delete(self, data) -> None: + call = self._calls.pop(int(data['channel_id']), None) + if call is not None: + call._deleteup() + self.dispatch('call_delete', call) + def parse_voice_state_update(self, data) -> None: guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) channel_id = utils._get_as_snowflake(data, 'channel_id') + session_id = data['session_id'] flags = self.member_cache_flags # self.user is *always* cached when this is called self_id = self.user.id # type: ignore - if guild is not None: - if int(data['user_id']) == self_id: - voice = self._get_voice_client(guild.id) - if voice is not None: - coro = voice.on_voice_state_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) + if int(data['user_id']) == self_id: + voice = self._get_voice_client(guild.id) + if voice is not None: + coro = voice.on_voice_state_update(data) + asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) + + if guild is not None member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: if flags.voice: if channel_id is None and flags._voice_only and member.id != self_id: - # Only remove from cache if we only have the voice flag enabled # Member doesn't meet the Snowflake protocol currently guild._remove_member(member) # type: ignore elif channel_id is not None: @@ -1292,18 +1516,24 @@ class ConnectionState: self.dispatch('voice_state_update', member, before, after) else: _log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) + else: + user, before, after = self._update_voice_state(data) + self.dispatch('voice_state_update', user, before, after) def parse_voice_server_update(self, data) -> None: - try: - key_id = int(data['guild_id']) - except KeyError: - key_id = int(data['channel_id']) + key_id = utils._get_as_snowflake(data, 'guild_id') + if key_id is None: + key_id = self.user.id vc = self._get_voice_client(key_id) if vc is not None: coro = vc.on_voice_server_update(data) asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) + def parse_user_required_action_update(self, data) -> None: + required_action = try_enum(RequiredActionType, data['required_action']) + self.dispatch('required_action_update', required_action) + def parse_typing_start(self, data) -> None: channel, guild = self._get_guild_channel(data) if channel is not None: @@ -1328,6 +1558,29 @@ class ConnectionState: timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) self.dispatch('typing', channel, member, timestamp) + def parse_user_required_action_update(self, data) -> None: + required_action = try_enum(RequiredActionType, data['required_action']) + self.dispatch('required_action_update', required_action) + + def parse_relationship_add(self, data) -> None: + key = int(data['id']) + old = self.user.get_relationship(key) + new = Relationship(state=self, data=data) + self.user._relationships[key] = new + if old is not None: + self.dispatch('relationship_update', old, new) + else: + self.dispatch('relationship_add', new) + + def parse_relationship_remove(self, data) -> None: + key = int(data['id']) + try: + old = self.user._relationships.pop(key) + except KeyError: + pass + else: + self.dispatch('relationship_remove', old) + def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, TextChannel): return channel.guild.get_member(user_id)