Browse Source

Migrate state.py

pull/10109/head
dolfies 4 years ago
parent
commit
c6e6c22a95
  1. 619
      discord/state.py

619
discord/state.py

@ -25,15 +25,15 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import asyncio
from collections import deque, OrderedDict
from collections import deque
import copy
import datetime
import itertools
import logging
from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque
import inspect
import time
import os
import random
from .guild import Guild
from .activity import BaseActivity
@ -46,18 +46,19 @@ from .channel import *
from .channel import _channel_factory
from .raw_models import *
from .member import Member
from .relationship import Relationship
from .role import Role
from .enums import ChannelType, try_enum, Status
from .enums import ChannelType, RequiredActionType, Status, try_enum, UnavailableGuildType, VoiceRegion
from . import utils
from .flags import MemberCacheFlags
from .flags import GuildSubscriptionOptions, MemberCacheFlags
from .object import Object
from .invite import Invite
from .integrations import _integration_factory
from .interactions import Interaction
from .ui.view import ViewStore, View
from .stage_instance import StageInstance
from .threads import Thread, ThreadMember
from .sticker import GuildSticker
from .settings import UserSettings
if TYPE_CHECKING:
from .abc import PrivateChannel
@ -67,6 +68,8 @@ if TYPE_CHECKING:
from .voice_client import VoiceProtocol
from .client import Client
from .gateway import DiscordWebSocket
from .calls import Call
from .member import VoiceState
from .types.activity import Activity as ActivityPayload
from .types.channel import DMChannel as DMChannelPayload
@ -75,6 +78,7 @@ if TYPE_CHECKING:
from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload
from .types.voice import GuildVoiceState
T = TypeVar('T')
CS = TypeVar('CS', bound='ConnectionState')
@ -136,7 +140,7 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) ->
try:
await coroutine
except Exception:
_log.exception('Exception occurred during %s', info)
_log.exception('Exception occurred during %s.', info)
class ConnectionState:
@ -153,10 +157,12 @@ class ConnectionState:
hooks: Dict[str, Callable],
http: HTTPClient,
loop: asyncio.AbstractEventLoop,
client: Client,
**options: Any,
) -> None:
self.loop: asyncio.AbstractEventLoop = loop
self.http: HTTPClient = http
self.client = client
self.max_messages: Optional[int] = options.get('max_messages', 1000)
if self.max_messages is not None and self.max_messages <= 0:
self.max_messages = 1000
@ -166,9 +172,6 @@ class ConnectionState:
self.hooks: Dict[str, Callable] = hooks
self._ready_task: Optional[asyncio.Task] = None
self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0)
self.guild_ready_timeout: float = options.get('guild_ready_timeout', 2.0)
if self.guild_ready_timeout < 0:
raise ValueError('guild_ready_timeout cannot be negative')
allowed_mentions = options.get('allowed_mentions')
@ -193,6 +196,16 @@ class ConnectionState:
status = str(status)
self._chunk_guilds: bool = options.get('chunk_guilds_at_startup', True)
self._request_guilds = options.get('request_guilds', True)
subscription_options = options.get('guild_subscription_options')
if subscription_options is None:
subscription_options = GuildSubscriptionOptions.off()
else:
if not isinstance(subscription_options, GuildSubscriptionOptions):
raise TypeError(f'subscription_options parameter must be GuildSubscriptionOptions not {type(subscription_options)!r}')
self._subscription_options = subscription_options
self._subscribe_guilds = subscription_options.auto_subscribe
cache_flags = options.get('member_cache_flags', None)
if cache_flags is None:
@ -216,33 +229,38 @@ class ConnectionState:
self.clear()
def clear(self, *, views: bool = True) -> None:
def clear(self) -> None:
self.user: Optional[ClientUser] = None
self.settings: Optional[UserSettings] = None
self.analytics_token: Optional[str] = None
# Originally, this code used WeakValueDictionary to maintain references to the
# global user mapping.
# global user mapping
# However, profiling showed that this came with two cons:
# 1. The __weakref__ slot caused a non-trivial increase in memory
# 2. The performance of the mapping caused store_user to be a bottleneck.
# 2. The performance of the mapping caused store_user to be a bottleneck
# Since this is undesirable, a mapping is now used instead with stored
# references now using a regular dictionary with eviction being done
# using __del__. Testing this for memory leaks led to no discernable leaks,
# though more testing will have to be done.
# using __del__
# Testing this for memory leaks led to no discernable leaks
self._users: Dict[int, User] = {}
self._emojis: Dict[int, Emoji] = {}
self._stickers: Dict[int, GuildSticker] = {}
self._guilds: Dict[int, Guild] = {}
if views:
self._view_store: ViewStore = ViewStore(self)
self._queued_guilds: Dict[int, Guild] = {}
self._unavailable_guilds: Dict[int, UnavailableGuildType] = {}
self._calls: Dict[int, Call] = {}
self._call_message_cache: List[Message] = [] # Hopefully this won't be a memory leak
self._voice_clients: Dict[int, VoiceProtocol] = {}
self._voice_states: Dict[int, VoiceState] = {}
# LRU of max size 128
self._private_channels: OrderedDict[int, PrivateChannel] = OrderedDict()
# extra dict to look up private channels by user id
self._private_channels: Dict[int, PrivateChannel] = {}
self._private_channels_by_user: Dict[int, DMChannel] = {}
self._last_private_channel: tuple = (None, None)
if self.max_messages is not None:
self._messages: Optional[Deque[Message]] = deque(maxlen=self.max_messages)
else:
@ -276,6 +294,10 @@ class ConnectionState:
else:
await coro(*args, **kwargs)
@property
def ws(self):
return self.client.ws
@property
def self_id(self) -> Optional[int]:
u = self.user
@ -285,8 +307,33 @@ class ConnectionState:
def voice_clients(self) -> List[VoiceProtocol]:
return list(self._voice_clients.values())
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[User, VoiceState, VoiceState]:
user_id = int(data['user_id'])
user = self.get_user(user_id)
channel = self._get_private_channel(channel_id)
try:
# Check if we should remove the voice state from cache
if channel is None:
after = self._voice_states.pop(user_id)
else:
after = self._voice_states[user_id]
before = copy.copy(after)
after._update(data, channel)
except KeyError:
# if we're here then add it into the cache
after = VoiceState(data=data, channel=channel)
before = VoiceState(data=data, channel=None)
self._voice_states[user_id] = after
return user, before, after
def _voice_state_for(self, user_id: int) -> Optional[VoiceState]:
return self._voice_states.get(user_id)
def _get_voice_client(self, guild_id: Optional[int]) -> Optional[VoiceProtocol]:
# the keys of self._voice_clients are ints
# The keys of self._voice_clients are ints
return self._voice_clients.get(guild_id) # type: ignore
def _add_voice_client(self, guild_id: int, voice: VoiceProtocol) -> None:
@ -302,7 +349,15 @@ class ConnectionState:
def store_user(self, data: UserPayload) -> User:
user_id = int(data['id'])
try:
return self._users[user_id]
user = self._users[user_id]
# We use the data available to us since we
# might not have events for that user
# However, the data may only have an ID
try:
user._update(data)
except KeyError:
pass
return user
except KeyError:
user = User(state=self, data=data)
if user.discriminator != '0000':
@ -317,14 +372,14 @@ class ConnectionState:
return User(state=self, data=data)
def deref_user_no_intents(self, user_id: int) -> None:
return
pass
def get_user(self, id: Optional[int]) -> Optional[User]:
# the keys of self._users are ints
# The keys of self._users are ints
return self._users.get(id) # type: ignore
def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji:
# the id will be present here
# The id will be present here
emoji_id = int(data['id']) # type: ignore
self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data)
return emoji
@ -334,23 +389,16 @@ class ConnectionState:
self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data)
return sticker
def store_view(self, view: View, message_id: Optional[int] = None) -> None:
self._view_store.add_view(view, message_id)
def prevent_view_updates_for(self, message_id: int) -> Optional[View]:
return self._view_store.remove_message_tracking(message_id)
@property
def persistent_views(self) -> Sequence[View]:
return self._view_store.persistent_views
@property
def guilds(self) -> List[Guild]:
return list(self._guilds.values())
def _get_guild(self, guild_id: Optional[int]) -> Optional[Guild]:
# the keys of self._guilds are ints
return self._guilds.get(guild_id) # type: ignore
# The keys of self._guilds are ints
guild = self._guilds.get(guild_id) # type: ignore
if guild is None:
guild = self._queued_guilds.get(guild_id) # type: ignore
return guild
def _add_guild(self, guild: Guild) -> None:
self._guilds[guild.id] = guild
@ -386,29 +434,39 @@ class ConnectionState:
def private_channels(self) -> List[PrivateChannel]:
return list(self._private_channels.values())
def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]:
async def access_private_channel(self, channel_id: int) -> None:
if not self._get_accessed_private_channel(channel_id):
await self._access_private_channel(channel_id)
self._set_accessed_private_channel(channel_id)
async def _access_private_channel(self, channel_id: int) -> None:
if (ws := self.ws) is None:
return
try:
# the keys of self._private_channels are ints
value = self._private_channels[channel_id] # type: ignore
except KeyError:
return None
else:
self._private_channels.move_to_end(channel_id) # type: ignore
return value
await ws.access_dm(channel_id)
except Exception as exc:
_log.warning('Sending ACCESS_DM failed for channel %s, (%s).', channel_id, exc)
def _set_accessed_private_channel(self, channel_id):
self._last_private_channel = (channel_id, time.time())
def _get_accessed_private_channel(self, channel_id):
timestamp, existing_id = self._last_private_channel
return existing_id == channel_id and int(time.time() - timestamp) < random.randrange(120000, 420000)
def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]:
# The keys of self._private_channels are ints
return self._private_channels.get(channel_id) # type: ignore
def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[DMChannel]:
# the keys of self._private_channels are ints
# The keys of self._private_channels are ints
return self._private_channels_by_user.get(user_id) # type: ignore
def _add_private_channel(self, channel: PrivateChannel) -> None:
channel_id = channel.id
self._private_channels[channel_id] = channel
if len(self._private_channels) > 128:
_, to_remove = self._private_channels.popitem(last=False)
if isinstance(to_remove, DMChannel) and to_remove.recipient:
self._private_channels_by_user.pop(to_remove.recipient.id, None)
if isinstance(channel, DMChannel) and channel.recipient:
self._private_channels_by_user[channel.recipient.id] = channel
@ -428,71 +486,77 @@ class ConnectionState:
def _get_message(self, msg_id: Optional[int]) -> Optional[Message]:
return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None
def _add_guild_from_data(self, data: GuildPayload) -> Guild:
guild = Guild(data=data, state=self)
self._add_guild(guild)
return guild
def _add_guild_from_data(self, guild: GuildPayload, *, from_ready: bool = False) -> Guild:
guild_id = int(guild['id'])
unavailable = guild.get('unavailable', False)
if not unavailable:
guild = Guild(data=guild, state=self)
self._add_guild(guild)
return guild
else:
self._unavailable_guilds[guild_id] = UnavailableGuildType.existing if from_ready else UnavailableGuildType.joined
_log.debug('Forcing GUILD_CREATE for unavailable guild %s.' % guild_id)
asyncio.ensure_future(self.request_guild(guild_id), loop=self.loop)
def _guild_needs_chunking(self, guild: Guild) -> bool:
# If presences are enabled then we get back the old guild.large behaviour
return self._chunk_guilds and not guild.chunked and not (True and not guild.large)
return self._chunk_guilds and not guild.chunked and any(
guild.me.guild_permissions.kick_members,
guild.me.guild_permissions.manage_roles,
guild.me.guild_permissions.ban_members
)
def _guild_needs_subscribing(self, guild): # TODO: rework
return not guild.subscribed and self._subscribe_guilds
def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]:
channel_id = int(data['channel_id'])
try:
guild = self._get_guild(int(data['guild_id']))
except KeyError:
channel = DMChannel._from_message(self, channel_id)
channel = self.get_channel(channel_id)
guild = None
else:
channel = guild and guild._resolve_channel(channel_id)
return channel or PartialMessageable(state=self, id=channel_id), guild
async def chunker(
self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None
def request_guild(self, guild_id: int) -> None:
return self.ws.request_lazy_guild(guild_id, typing=True, activities=True, threads=True)
def chunker(
self, guild_id: int, query: str = '', limit: int = 0, presences: bool = True, *, nonce: Optional[str] = None
) -> None:
ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
return self.ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool):
guild_id = guild.id
ws = self._get_websocket(guild_id)
if ws is None:
raise RuntimeError('Somehow do not have a websocket for this guild_id')
request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache)
self._chunk_requests[request.nonce] = request
try:
# start the query operation
await ws.request_chunks(
guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce
await self.ws.request_chunks(
[guild_id], query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce
)
return await asyncio.wait_for(request.wait(), timeout=30.0)
except asyncio.TimeoutError:
_log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id)
_log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d.', query, limit, guild_id)
raise
async def _delay_ready(self) -> None:
try:
states = []
while True:
# this snippet of code is basically waiting N seconds
# until the last GUILD_CREATE was sent
try:
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
except asyncio.TimeoutError:
break
else:
if self._guild_needs_chunking(guild):
future = await self.chunk_guild(guild, wait=False)
states.append((guild, future))
else:
if guild.unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
subscribes = []
for guild in self._guilds.values():
if self._request_guilds:
await self.request_guild(guild.id)
if self._guild_needs_chunking(guild):
future = await self.chunk_guild(guild, wait=False)
states.append((guild, future))
if self._guild_needs_subscribing(guild):
subscribes.append(guild)
for guild, future in states:
try:
@ -500,48 +564,92 @@ class ConnectionState:
except asyncio.TimeoutError:
_log.warning('Timed out waiting for chunks for guild_id %s.', guild.id)
if guild.unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
# remove the state
try:
del self._ready_state
except AttributeError:
pass # already been deleted somehow
options = self._subscription_options
ticket = asyncio.Semaphore(options.concurrent_guilds)
await asyncio.gather(*[guild.subscribe(ticket=ticket, max_online=options.max_online) for guild in subscribes])
except asyncio.CancelledError:
pass
else:
# dispatch the event
# Dispatch the event
self.call_handlers('ready')
self.dispatch('ready')
finally:
self._ready_task = None
def parse_ready(self, data) -> None:
# Before parsing, we wait for READY_SUPPLEMENTAL
# This has voice state objects, as well as an initial member cache
self._ready_data: dict = data
def parse_ready_supplemental(self, data) -> None:
if self._ready_task is not None:
self._ready_task.cancel()
self._ready_state = asyncio.Queue()
self.clear(views=False)
self.user = ClientUser(state=self, data=data['user'])
self.store_user(data['user'])
self.clear()
if self.application_id is None:
# Merge with READY data
extra_data = data
data = self._ready_data
# Discord bad
for guild_data, guild_extra, merged_members, merged_me, merged_presences in zip(
data.get('guilds', []),
extra_data.get('guilds', []),
extra_data.get('merged_members', []),
data.get('merged_members', []),
extra_data['merged_presences'].get('guilds', [])
):
guild_data['voice_states'] = guild_extra.get('voice_states', [])
guild_data['merged_members'] = merged_me
guild_data['merged_members'].extend(merged_members)
guild_data['merged_presences'] = merged_presences
# There's also a friends key that has presence data for your friends
# Parsing that would require a redesign of the Relationship class ;-;
# Self parsing
self.user = ClientUser(state=self, data=data['user'])
user = self.store_user(data['user'])
# Temp user parsing
temp_users = {user.id: user._to_minimal_user_json()}
for u in data.get('users', []):
u_id = int(u['id'])
temp_users[u_id] = u
# Guild parsing
for guild_data in data.get('guilds', []):
for member in guild_data['merged_members']:
if 'user' not in member:
member['user'] = temp_users.get(int(member.pop('user_id')))
self._add_guild_from_data(guild_data, from_ready=True)
# Relationship parsing
for relationship in data.get('relationships', []):
try:
application = data['application']
r_id = int(relationship['id'])
except KeyError:
pass
continue
else:
self.application_id = utils._get_as_snowflake(application, 'id')
# flags will always be present here
self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore
for guild_data in data['guilds']:
self._add_guild_from_data(guild_data)
if 'user' not in relationship:
relationship['user'] = temp_users[int(relationship.pop('user_id'))]
user._relationships[r_id] = Relationship(state=self, data=relationship)
# Private channel parsing
for pm in data.get('private_channels', []):
factory, _ = _channel_factory(pm['type'])
if 'recipients' not in pm:
pm['recipients'] = [temp_users[int(u_id)] for u_id in pm.pop('recipient_ids')]
self._add_private_channel(factory(me=user, data=pm, state=self))
# Extras
region = data.get('geo_ordered_rtc_regions', ['us-west'])[0]
self.preferred_region = try_enum(VoiceRegion, region)
self.settings = UserSettings(data=data.get('user_settings', {}), state=self)
# We're done
del self._ready_data
self.call_handlers('connect')
self.dispatch('connect')
self._ready_task = asyncio.create_task(self._delay_ready())
@ -549,13 +657,20 @@ class ConnectionState:
self.dispatch('resumed')
def parse_message_create(self, data) -> None:
guild_id = utils._get_as_snowflake(data, 'guild_id')
channel, _ = self._get_guild_channel(data)
# channel would be the correct type here
if guild_id in self._unavailable_guilds: # I don't know how I feel about this :(
return
# Channel will be the correct type here
message = Message(channel=channel, data=data, state=self) # type: ignore
self.dispatch('message', message)
if self._messages is not None:
self._messages.append(message)
# we ensure that the channel is either a TextChannel or Thread
if message.call is not None:
self._call_message_cache[message.id] = message
# We ensure that the channel is either a TextChannel or Thread
if channel and channel.__class__ in (TextChannel, Thread):
channel.last_message_id = message.id # type: ignore
@ -597,9 +712,6 @@ class ConnectionState:
else:
self.dispatch('raw_message_edit', raw)
if 'components' in data and self._view_store.is_message_tracked(raw.message_id):
self._view_store.update_from_message(raw.message_id, data['components'])
def parse_message_reaction_add(self, data) -> None:
emoji = data['emoji']
emoji_id = utils._get_as_snowflake(emoji, 'id')
@ -673,15 +785,6 @@ class ConnectionState:
if reaction:
self.dispatch('reaction_clear_emoji', reaction)
def parse_interaction_create(self, data) -> None:
interaction = Interaction(data=data, state=self)
if data['type'] == 3: # interaction component
custom_id = interaction.data['custom_id'] # type: ignore
component_type = interaction.data['component_type'] # type: ignore
self._view_store.dispatch(component_type, custom_id, interaction)
self.dispatch('interaction', interaction)
def parse_presence_update(self, data) -> None:
guild_id = utils._get_as_snowflake(data, 'guild_id')
# guild_id won't be None here
@ -791,6 +894,22 @@ class ConnectionState:
else:
self.dispatch('guild_channel_pins_update', channel, last_pin)
def parse_channel_recipient_add(self, data) -> None:
channel = self._get_private_channel(int(data['channel_id']))
user = self.store_user(data['user'])
channel.recipients.append(user)
self.dispatch('group_join', channel, user)
def parse_channel_recipient_remove(self, data) -> None:
channel = self._get_private_channel(int(data['channel_id']))
user = self.store_user(data['user'])
try:
channel.recipients.remove(user)
except ValueError:
pass
else:
self.dispatch('group_remove', channel, user)
def parse_thread_create(self, data) -> None:
guild_id = int(data['guild_id'])
guild: Optional[Guild] = self._get_guild(guild_id)
@ -934,7 +1053,7 @@ class ConnectionState:
except AttributeError:
pass
self.dispatch('member_join', member)
# self.dispatch('member_join', member)
def parse_guild_member_remove(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
@ -948,7 +1067,7 @@ class ConnectionState:
member = guild.get_member(user_id)
if member is not None:
guild._remove_member(member) # type: ignore
self.dispatch('member_remove', member)
# self.dispatch('member_remove', member)
else:
_log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
@ -981,6 +1100,100 @@ class ConnectionState:
guild._add_member(member)
_log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id)
def parse_guild_sync(self, data) -> None:
print('I noticed you triggered a `GUILD_SYNC`.\nIf you want to share your secrets, please feel free to email me.')
def parse_guild_member_list_update(self, data) -> None: # Rewrite incoming...
self.dispatch('raw_guild_member_list_update', data)
guild = self._get_guild(int(data['guild_id']))
if guild is None:
_log.debug('GUILD_MEMBER_LIST_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
return
ops = data['ops']
if data['member_count'] > 0:
guild._member_count = data['member_count']
online_count = 0
for group in data['groups']:
online_count += group['count'] if group['id'] != 'offline' else 0
guild._online_count = online_count
for opdata in ops:
op = opdata['op']
# There are two OPs I'm not parsing.
# INVALIDATE: Usually invalid (hehe).
# DELETE: Sends the index, not the user ID, so I can't do anything with
# it unless I keep a seperate list of the member sidebar (maybe in future).
if op == 'SYNC':
members = [Member(guild=guild, data=member['member'], state=self) for member in [item for item in opdata.get('items', []) if 'member' in item]]
member_dict = {str(member.id): member for member in members}
for presence in [item for item in opdata.get('items', []) if 'member' in item]:
presence = presence['member']['presence']
user = presence['user']
member_id = user['id']
member = member_dict.get(member_id)
member._presence_update(presence, user)
for member in members:
guild._add_member(member)
if op == 'INSERT':
if 'member' not in opdata['item']:
# Hoisted role INSERT
return
mdata = opdata['item']['member']
user = mdata['user']
user_id = int(user['id'])
member = guild.get_member(user_id)
if member is not None: # INSERTs are also sent when a user changes range
old_member = Member._copy(member)
member._update(mdata)
user_update = member._update_inner_user(user)
if 'presence' in mdata:
presence = mdata['presence']
user = presence['user']
member_id = user['id']
member._presence_update(presence, user)
if user_update:
self.dispatch('user_update', user_update[0], user_update[1])
self.dispatch('member_update', old_member, member)
else:
member = Member(data=mdata, guild=guild, state=self)
guild._add_member(member)
if op == 'UPDATE':
if 'member' not in opdata['item']:
# Hoisted role UPDATE
return
mdata = opdata['item']['member']
user = mdata['user']
user_id = int(user['id'])
member = guild.get_member(user_id)
if member is not None:
old_member = Member._copy(member)
member._update(mdata)
user_update = member._update_inner_user(user)
if 'presence' in mdata:
presence = mdata['presence']
user = presence['user']
member_id = user['id']
member._presence_update(presence, user)
if user_update:
self.dispatch('user_update', user_update[0], user_update[1])
self.dispatch('member_update', old_member, member)
else:
_log.debug('GUILD_MEMBER_LIST_UPDATE type UPDATE referencing an unknown member ID: %s. Discarding.', user_id)
def parse_guild_emojis_update(self, data) -> None:
guild = self._get_guild(int(data['guild_id']))
if guild is None:
@ -1008,15 +1221,12 @@ class ConnectionState:
self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers)
def _get_create_guild(self, data):
if data.get('unavailable') is False:
# GUILD_CREATE with unavailable in the response
# usually means that the guild has become available
# and is therefore in the cache
guild = self._get_guild(int(data['id']))
if guild is not None:
guild.unavailable = False
guild._from_data(data)
return guild
guild = self._get_guild(int(data['id']))
# Discord being Discord sends a GUILD_CREATE after an OPCode 14 is sent (a la bots)
# However, we want that if we forced a GUILD_CREATE for an unavailable guild
if guild is not None:
guild._from_data(data)
return
return self._add_guild_from_data(data)
@ -1034,44 +1244,44 @@ class ConnectionState:
return await request.wait()
return request.get_future()
async def _chunk_and_dispatch(self, guild, unavailable):
try:
await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0)
except asyncio.TimeoutError:
_log.info('Somehow timed out waiting for chunks.')
async def _parse_and_dispatch(self, guild, *, chunk, subscribe) -> None:
self._queued_guilds[guild.id] = guild
if chunk:
try:
await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0)
except asyncio.TimeoutError:
log.info('Somehow timed out waiting for chunks.')
if unavailable is False:
self.dispatch('guild_available', guild)
if subscribe:
await guild.subscribe(max_online=self._subscription_options.max_online)
self._queued_guilds.pop(guild.id)
# Dispatch available/join depending on circumstances
if guild.id in self._unavailable_guilds:
type = self._unavailable_guilds.pop(guild.id)
if type is UnavailableGuildType.existing:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
else:
self.dispatch('guild_join', guild)
def parse_guild_create(self, data) -> None:
unavailable = data.get('unavailable')
if unavailable is True:
# joined a guild with unavailable == True so..
return
def parse_guild_create(self, data):
guild_id = int(data['id'])
guild = self._get_create_guild(data)
try:
# Notify the on_ready state, if any, that this guild is complete.
self._ready_state.put_nowait(guild)
except AttributeError:
pass
else:
# If we're waiting for the event, put the rest on hold
if guild is None:
return
# check if it requires chunking
if self._guild_needs_chunking(guild):
asyncio.create_task(self._chunk_and_dispatch(guild, unavailable))
return
if self._request_guilds:
asyncio.ensure_future(self.request_guild(guild.id), loop=self.loop)
# Dispatch available if newly available
if unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
# Chunk/subscribe if needed
needs_chunking, needs_subscribing = self._guild_needs_chunking(guild), self._guild_needs_subscribing(guild)
asyncio.ensure_future(self._parse_and_dispatch(guild, chunk=needs_chunking, subscribe=needs_subscribing), loop=self.loop)
def parse_guild_update(self, data) -> None:
guild = self._get_guild(int(data['id']))
@ -1095,7 +1305,7 @@ class ConnectionState:
self.dispatch('guild_unavailable', guild)
return
# do a cleanup of the messages cache
# Cleanup the message cache
if self._messages is not None:
self._messages: Optional[Deque[Message]] = deque(
(msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages
@ -1105,11 +1315,6 @@ class ConnectionState:
self.dispatch('guild_remove', guild)
def parse_guild_ban_add(self, data) -> None:
# we make the assumption that GUILD_BAN_ADD is done
# before GUILD_MEMBER_REMOVE is called
# hence we don't remove it from cache or do anything
# strange with it, the main purpose of this event
# is mainly to dispatch to another event worth listening to for logging
guild = self._get_guild(int(data['guild_id']))
if guild is not None:
try:
@ -1168,7 +1373,7 @@ class ConnectionState:
guild = self._get_guild(guild_id)
presences = data.get('presences', [])
# the guild won't be None here
# The guild won't be None here
members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore
_log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id)
@ -1266,24 +1471,43 @@ class ConnectionState:
else:
_log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_call_create(self, data) -> None:
channel = self._get_private_channel(int(data['channel_id']))
message = self._call_message_cache.pop((int(data['message_id'])), None)
call = channel._add_call(state=self, message=message, channel=channel, **data)
self._calls[channel.id] = call
self.dispatch('call_create', call)
def parse_call_update(self, data) -> None:
call = self._calls.get(int(data['channel_id']))
call._update(**data)
self.dispatch('call_update', call)
def parse_call_delete(self, data) -> None:
call = self._calls.pop(int(data['channel_id']), None)
if call is not None:
call._deleteup()
self.dispatch('call_delete', call)
def parse_voice_state_update(self, data) -> None:
guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id'))
channel_id = utils._get_as_snowflake(data, 'channel_id')
session_id = data['session_id']
flags = self.member_cache_flags
# self.user is *always* cached when this is called
self_id = self.user.id # type: ignore
if guild is not None:
if int(data['user_id']) == self_id:
voice = self._get_voice_client(guild.id)
if voice is not None:
coro = voice.on_voice_state_update(data)
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
if int(data['user_id']) == self_id:
voice = self._get_voice_client(guild.id)
if voice is not None:
coro = voice.on_voice_state_update(data)
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
if guild is not None
member, before, after = guild._update_voice_state(data, channel_id) # type: ignore
if member is not None:
if flags.voice:
if channel_id is None and flags._voice_only and member.id != self_id:
# Only remove from cache if we only have the voice flag enabled
# Member doesn't meet the Snowflake protocol currently
guild._remove_member(member) # type: ignore
elif channel_id is not None:
@ -1292,18 +1516,24 @@ class ConnectionState:
self.dispatch('voice_state_update', member, before, after)
else:
_log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id'])
else:
user, before, after = self._update_voice_state(data)
self.dispatch('voice_state_update', user, before, after)
def parse_voice_server_update(self, data) -> None:
try:
key_id = int(data['guild_id'])
except KeyError:
key_id = int(data['channel_id'])
key_id = utils._get_as_snowflake(data, 'guild_id')
if key_id is None:
key_id = self.user.id
vc = self._get_voice_client(key_id)
if vc is not None:
coro = vc.on_voice_server_update(data)
asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
def parse_user_required_action_update(self, data) -> None:
required_action = try_enum(RequiredActionType, data['required_action'])
self.dispatch('required_action_update', required_action)
def parse_typing_start(self, data) -> None:
channel, guild = self._get_guild_channel(data)
if channel is not None:
@ -1328,6 +1558,29 @@ class ConnectionState:
timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc)
self.dispatch('typing', channel, member, timestamp)
def parse_user_required_action_update(self, data) -> None:
required_action = try_enum(RequiredActionType, data['required_action'])
self.dispatch('required_action_update', required_action)
def parse_relationship_add(self, data) -> None:
key = int(data['id'])
old = self.user.get_relationship(key)
new = Relationship(state=self, data=data)
self.user._relationships[key] = new
if old is not None:
self.dispatch('relationship_update', old, new)
else:
self.dispatch('relationship_add', new)
def parse_relationship_remove(self, data) -> None:
key = int(data['id'])
try:
old = self.user._relationships.pop(key)
except KeyError:
pass
else:
self.dispatch('relationship_remove', old)
def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]:
if isinstance(channel, TextChannel):
return channel.guild.get_member(user_id)

Loading…
Cancel
Save