You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

384 lines
15 KiB

import six
import weakref
from collections import deque, namedtuple
from gevent.event import Event
from disco.types.base import UNSET
from disco.util.config import Config
from disco.util.string import underscore
from disco.util.hashmap import HashMap, DefaultHashMap
from disco.util.emitter import Priority
class StackMessage(namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])):
    """
    A message stored on a stack inside of the state object, used for tracking
    previously sent messages in channels.

    Attributes
    ----------
    id : snowflake
        the id of the message
    channel_id : snowflake
        the id of the channel this message was sent in
    author_id : snowflake
        the id of the author of this message
    """
    # Without an empty __slots__, instances of this subclass would each carry a
    # per-instance __dict__, losing the compactness of the underlying namedtuple
    # that makes tracking many messages cheap.
    __slots__ = ()
class StateConfig(Config):
    """
    A configuration object for determining how the State tracking behaves.

    Attributes
    ----------
    track_messages : bool
        Whether the state store should keep a buffer of previously sent messages.
        Message tracking allows for multiple higher-level shortcuts and can be
        highly useful when developing bots that need to delete their own messages.

        Message tracking is implemented using a deque and a namedtuple, meaning
        it should generally not have a high impact on memory. However, users who
        find they do not need it and may be experiencing memory pressure can
        disable this feature entirely using this attribute.
    track_messages_size : int
        The size of the messages deque for each channel. This value can be used
        to calculate the total number of possible `StackMessage` objects kept in
        memory, simply: `track_messages_size * total_channels`. This value can
        be tweaked based on usage and to help prevent memory pressure.
    sync_guild_members : bool
        If true, guilds will be automatically synced when they are initially loaded
        or joined. Generally this setting is OK for smaller bots; however, bots in
        over 50 guilds will notice this operation can take a while to complete and
        may want to batch requests using the underlying
        `GatewayClient.request_guild_members` interface.
    """
    track_messages = True
    track_messages_size = 100

    sync_guild_members = True
class State(object):
    """
    The State class is used to track global state based on events emitted from
    the `GatewayClient`. State tracking is a core component of the Disco client,
    providing the mechanism for most of the higher-level utility functions.

    Attributes
    ----------
    EVENTS : list(str)
        A list of all events the State object binds to
    client : `disco.client.Client`
        The Client instance this state is attached to
    config : `StateConfig`
        The configuration for this state instance
    ready : `gevent.event.Event`
        Set once every guild listed in READY has arrived through an available
        GUILD_CREATE (or immediately if READY lists no guilds)
    guilds_waiting_sync : int
        Number of guilds from READY still waiting for an available GUILD_CREATE
    me : `User`
        The currently logged in user
    dms : dict(snowflake, `Channel`)
        Mapping of all known DM Channels
    guilds : dict(snowflake, `Guild`)
        Mapping of all known/loaded Guilds
    channels : dict(snowflake, `Channel`)
        Weak mapping of all known/loaded Channels
    users : dict(snowflake, `User`)
        Weak mapping of all known/loaded Users
    voice_clients : dict(snowflake, `VoiceClient`)
        Weak mapping of all known voice clients, keyed by guild id
    voice_states : dict(str, `VoiceState`)
        Weak mapping of all known/active Voice States, keyed by session id
    messages : Optional[dict(snowflake, deque)]
        Mapping of channel ids to deques containing `StackMessage` objects, or
        None when message tracking is disabled
    listeners : list
        The listener handles returned when binding events in `bind`
    """
    EVENTS = [
        'Ready', 'GuildCreate', 'GuildUpdate', 'GuildDelete', 'GuildMemberAdd', 'GuildMemberRemove',
        'GuildMemberUpdate', 'GuildMembersChunk', 'GuildRoleCreate', 'GuildRoleUpdate', 'GuildRoleDelete',
        'GuildEmojisUpdate', 'ChannelCreate', 'ChannelUpdate', 'ChannelDelete', 'VoiceServerUpdate', 'VoiceStateUpdate',
        'MessageCreate', 'PresenceUpdate',
    ]

    def __init__(self, client, config):
        self.client = client
        self.config = config

        self.ready = Event()
        self.guilds_waiting_sync = 0

        self.me = None
        self.dms = HashMap()
        self.guilds = HashMap()
        self.channels = HashMap(weakref.WeakValueDictionary())
        self.users = HashMap(weakref.WeakValueDictionary())
        self.voice_clients = HashMap(weakref.WeakValueDictionary())
        self.voice_states = HashMap(weakref.WeakValueDictionary())

        # If message tracking is enabled, listen to those events. Build a new
        # list instead of using ``+=``: augmented assignment on a list extends
        # the class-level EVENTS in place, so every later State instance would
        # inherit (and re-append) the message events and bind them repeatedly.
        if self.config.track_messages:
            self.messages = DefaultHashMap(lambda: deque(maxlen=self.config.track_messages_size))
            self.EVENTS = self.EVENTS + ['MessageDelete', 'MessageDeleteBulk']
        else:
            # Matches the documented Optional type instead of leaving the
            # attribute undefined.
            self.messages = None

        # The bound listener objects
        self.listeners = []
        self.bind()

    def unbind(self):
        """
        Unbinds all bound event listeners for this state object.
        """
        # An explicit loop is required: ``map`` is lazy on Python 3, so a
        # ``map(lambda k: k.unbind(), ...)`` whose result is discarded never
        # actually calls unbind.
        for listener in self.listeners:
            listener.unbind()
        self.listeners = []

    def bind(self):
        """
        Binds all events for this state object, storing the listeners for later
        unbinding.

        Each event name in EVENTS is routed to the method ``on_<underscored name>``
        with ``Priority.AFTER``.
        """
        assert not len(self.listeners), 'Binding while already bound is dangerous'

        for event in self.EVENTS:
            func = 'on_' + underscore(event)
            self.listeners.append(self.client.events.on(event, getattr(self, func), priority=Priority.AFTER))

    def fill_messages(self, channel):
        """
        Seeds the tracked message deque for `channel` from the first bulk page
        returned by ``channel.messages_iter(bulk=True)``.

        Requires message tracking to be enabled. The page is appended in reverse,
        presumably because pages come newest-first (TODO confirm), so the deque
        ends up ordered like live-tracked messages. A channel that yields no page
        is a no-op instead of leaking ``StopIteration`` to the caller.
        """
        for message in reversed(next(channel.messages_iter(bulk=True), [])):
            self.messages[channel.id].append(
                StackMessage(message.id, message.channel_id, message.author.id))

    def on_ready(self, event):
        """Records the current user, the number of guilds to await, and DM channels."""
        self.me = event.user
        self.guilds_waiting_sync = len(event.guilds)

        for dm in event.private_channels:
            self.dms[dm.id] = dm
            self.channels[dm.id] = dm

        # `ready` is otherwise only set from on_guild_create; with no guilds
        # pending there is nothing to wait for, so set it now.
        if self.guilds_waiting_sync <= 0:
            self.ready.set()

    def on_user_update(self, event):
        """Applies a user update to the cached current user."""
        # NOTE(review): 'UserUpdate' is not listed in EVENTS, so bind() never
        # attaches this handler; confirm whether it should be registered.
        self.me.inplace_update(event.user)

    def on_message_create(self, event):
        """Tracks the new message (if enabled) and updates the channel's last_message_id."""
        if self.config.track_messages:
            self.messages[event.message.channel_id].append(
                StackMessage(event.message.id, event.message.channel_id, event.message.author.id))

        if event.message.channel_id in self.channels:
            self.channels[event.message.channel_id].last_message_id = event.message.id

    def on_message_delete(self, event):
        """Removes a single deleted message from its channel's tracked deque."""
        if event.channel_id not in self.messages:
            return

        tracked = self.messages[event.channel_id]
        sm = next((i for i in tracked if i.id == event.id), None)
        if sm is not None:
            tracked.remove(sm)

    def on_message_delete_bulk(self, event):
        """Removes every bulk-deleted message from its channel's tracked deque."""
        if event.channel_id not in self.messages:
            return

        ids = set(event.ids)
        tracked = self.messages[event.channel_id]
        kept = [sm for sm in tracked if sm.id not in ids]
        # Rebuild in place (single O(n) pass) so the same deque object and its
        # maxlen are preserved, instead of a deque.remove() per deleted id.
        tracked.clear()
        tracked.extend(kept)

    def on_guild_create(self, event):
        """Caches a created/available guild along with its channels, users, presences and voice states."""
        # Available guilds count toward the READY sync; once none are pending,
        # the state is considered ready.
        if event.unavailable is False:
            self.guilds_waiting_sync -= 1
            if self.guilds_waiting_sync <= 0:
                self.ready.set()

        self.guilds[event.guild.id] = event.guild
        self.channels.update(event.guild.channels)

        # Only add users not already cached, keeping existing User objects.
        for member in six.itervalues(event.guild.members):
            if member.user.id not in self.users:
                self.users[member.user.id] = member.user

        for presence in event.presences:
            if presence.user.id in self.users:
                self.users[presence.user.id].presence = presence

        for voice_state in six.itervalues(event.guild.voice_states):
            self.voice_states[voice_state.session_id] = voice_state

        if self.config.sync_guild_members:
            event.guild.request_guild_members()

    def on_guild_update(self, event):
        """Updates a cached guild in place, leaving its collections untouched."""
        if event.guild.id not in self.guilds:
            return

        self.guilds[event.guild.id].inplace_update(event.guild, ignored=[
            'channels',
            'members',
            'voice_states',
            'presences',
        ])

    def on_guild_delete(self, event):
        """Drops a deleted guild and disconnects any voice client for it."""
        if event.id in self.guilds:
            # Just delete the guild; its channels are only weakly referenced by
            # self.channels, so those entries should fall out once unreferenced.
            del self.guilds[event.id]

        if event.id in self.voice_clients:
            self.voice_clients[event.id].disconnect()

    def on_channel_create(self, event):
        """Caches a new guild channel (for a known guild) or DM channel."""
        if event.channel.is_guild and event.channel.guild_id in self.guilds:
            self.guilds[event.channel.guild_id].channels[event.channel.id] = event.channel
            self.channels[event.channel.id] = event.channel
        elif event.channel.is_dm:
            self.dms[event.channel.id] = event.channel
            self.channels[event.channel.id] = event.channel

    def on_channel_update(self, event):
        """Updates a cached channel in place, including its permission overwrites when sent."""
        if event.channel.id in self.channels:
            self.channels[event.channel.id].inplace_update(event.channel)

            if event.overwrites is not UNSET:
                self.channels[event.channel.id].overwrites = event.overwrites
                self.channels[event.channel.id].after_load()

    def on_channel_delete(self, event):
        """Removes a deleted channel from its guild or from the DM mapping."""
        # self.channels is weak, so its entry should disappear once these
        # strong references are gone.
        if event.channel.is_guild and event.channel.guild and event.channel.id in event.channel.guild.channels:
            del event.channel.guild.channels[event.channel.id]
        elif event.channel.is_dm and event.channel.id in self.dms:
            del self.dms[event.channel.id]

    def on_voice_server_update(self, event):
        """Passes a new voice endpoint and token to the guild's voice client, if any."""
        # Single lookup: voice_clients is weak, so an entry seen by an ``in``
        # check could be collected before a following ``get``.
        voice_client = self.voice_clients.get(event.guild_id)
        if voice_client is None:
            return

        voice_client.set_endpoint(event.endpoint)
        voice_client.set_token(event.token)

    def on_voice_state_update(self, event):
        """Tracks voice connections, channel moves and disconnections."""
        # Existing connection, we are either moving channels or disconnecting
        if event.state.session_id in self.voice_states:
            # Moving channels
            if event.state.channel_id:
                self.voice_states[event.state.session_id].inplace_update(event.state)
            # Disconnection
            else:
                if event.state.guild_id in self.guilds:
                    if event.state.session_id in self.guilds[event.state.guild_id].voice_states:
                        del self.guilds[event.state.guild_id].voice_states[event.state.session_id]
                del self.voice_states[event.state.session_id]
        # New connection
        elif event.state.channel_id:
            # Replace any previous voice state held by this user (looked up by
            # user_id) before storing the new session.
            if event.state.guild_id in self.guilds:
                expired_voice_state = self.guilds[event.state.guild_id].voice_states.select_one(user_id=event.user_id)
                if expired_voice_state:
                    del self.guilds[event.state.guild_id].voice_states[expired_voice_state.session_id]
                self.guilds[event.state.guild_id].voice_states[event.state.session_id] = event.state
            expired_voice_state = self.voice_states.select_one(user_id=event.user_id)
            if expired_voice_state:
                del self.voice_states[expired_voice_state.session_id]
            self.voice_states[event.state.session_id] = event.state

    def on_guild_member_add(self, event):
        """Caches a new member, sharing the cached User and bumping member_count."""
        if event.member.user.id not in self.users:
            self.users[event.member.user.id] = event.member.user
        else:
            event.member.user = self.users[event.member.user.id]

        if event.member.guild_id not in self.guilds:
            return

        if (self.guilds[event.member.guild_id].member_count is not UNSET and
                # Avoid adding duplicate events to member_count.
                event.member.id not in self.guilds[event.member.guild_id].members):
            self.guilds[event.member.guild_id].member_count += 1

        self.guilds[event.member.guild_id].members[event.member.id] = event.member

    def on_guild_member_update(self, event):
        """Updates a cached guild member in place."""
        if event.member.guild_id not in self.guilds:
            return

        if event.member.id not in self.guilds[event.member.guild_id].members:
            return

        self.guilds[event.member.guild_id].members[event.member.id].inplace_update(event.member)

    def on_guild_member_remove(self, event):
        """Removes a cached guild member and decrements member_count."""
        if event.guild_id not in self.guilds:
            return

        if event.user.id not in self.guilds[event.guild_id].members:
            return

        if self.guilds[event.guild_id].member_count is not UNSET:
            self.guilds[event.guild_id].member_count -= 1

        del self.guilds[event.guild_id].members[event.user.id]

    def on_guild_members_chunk(self, event):
        """Caches a chunk of members, sharing cached User objects where known."""
        if event.guild_id not in self.guilds:
            return

        guild = self.guilds[event.guild_id]
        for member in event.members:
            member.guild_id = guild.id
            guild.members[member.id] = member

            if member.id not in self.users:
                self.users[member.id] = member.user
            else:
                member.user = self.users[member.id]

    def on_guild_role_create(self, event):
        """Adds a new role to its cached guild."""
        if event.guild_id not in self.guilds:
            return

        self.guilds[event.guild_id].roles[event.role.id] = event.role

    def on_guild_role_update(self, event):
        """Updates a cached role in place, storing it if it was not yet known."""
        if event.guild_id not in self.guilds:
            return

        roles = self.guilds[event.guild_id].roles
        if event.role.id in roles:
            roles[event.role.id].inplace_update(event.role)
        else:
            # Previously a KeyError; store the role as on_guild_role_create does.
            roles[event.role.id] = event.role

    def on_guild_role_delete(self, event):
        """Removes a deleted role from its cached guild."""
        if event.guild_id not in self.guilds:
            return

        if event.role_id not in self.guilds[event.guild_id].roles:
            return

        del self.guilds[event.guild_id].roles[event.role_id]

    def on_guild_emojis_update(self, event):
        """Replaces a cached guild's emoji mapping with the updated set."""
        if event.guild_id not in self.guilds:
            return

        for emoji in event.emojis:
            emoji.guild_id = event.guild_id

        self.guilds[event.guild_id].emojis = HashMap({i.id: i for i in event.emojis})

    def on_presence_update(self, event):
        """Updates a user's presence and, when sent, their roles in a cached guild."""
        # TODO: this is recursive, we hackfix in model, but its still lame ATM
        user = event.presence.user
        user.presence = event.presence

        # if we have the user tracked locally, we can just use the presence
        # update to update both their presence and the cached user object.
        if user.id in self.users:
            self.users[user.id].inplace_update(user)
        else:
            # Otherwise this user does not exist in our local cache, so we can
            # use this opportunity to add them. They will quickly fall out of
            # scope and be deleted if they aren't used below
            self.users[user.id] = user

        # Some updates come with a guild_id and roles the user is in, we should
        # use this to update the guild member, but only if we have the guild
        # cached.
        if event.roles is UNSET or event.guild_id not in self.guilds:
            return

        if user.id not in self.guilds[event.guild_id].members:
            return

        self.guilds[event.guild_id].members[user.id].roles = event.roles