|
|
@ -47,18 +47,20 @@ class ListenerType(enum.Enum): |
|
|
|
chunk = 0 |
|
|
|
|
|
|
|
Listener = namedtuple('Listener', ('type', 'future', 'predicate')) |
|
|
|
StateContext = namedtuple('StateContext', 'try_insert_user http') |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
ReadyState = namedtuple('ReadyState', ('launch', 'servers')) |
|
|
|
|
|
|
|
class ConnectionState: |
|
|
|
def __init__(self, dispatch, chunker, syncer, max_messages, *, loop): |
|
|
|
def __init__(self, *, dispatch, chunker, syncer, http, loop, **options): |
|
|
|
self.loop = loop |
|
|
|
self.max_messages = max_messages |
|
|
|
self.max_messages = max(options.get('max_messages', 5000), 100) |
|
|
|
self.dispatch = dispatch |
|
|
|
self.chunker = chunker |
|
|
|
self.syncer = syncer |
|
|
|
self.is_bot = None |
|
|
|
self._listeners = [] |
|
|
|
self.ctx = StateContext(try_insert_user=self.try_insert_user, http=http) |
|
|
|
self.clear() |
|
|
|
|
|
|
|
def clear(self): |
|
|
@ -66,6 +68,7 @@ class ConnectionState: |
|
|
|
self.sequence = None |
|
|
|
self.session_id = None |
|
|
|
self._calls = {} |
|
|
|
self._users = {} |
|
|
|
self._servers = {} |
|
|
|
self._voice_clients = {} |
|
|
|
self._private_channels = {} |
|
|
@ -116,6 +119,15 @@ class ConnectionState: |
|
|
|
for vc in self.voice_clients: |
|
|
|
vc.main_ws = ws |
|
|
|
|
|
|
|
def try_insert_user(self, data): |
|
|
|
# this way is 300% faster than `dict.setdefault`. |
|
|
|
user_id = data['id'] |
|
|
|
try: |
|
|
|
return self._users[user_id] |
|
|
|
except KeyError: |
|
|
|
self._users[user_id] = user = User(state=self.ctx, data=data) |
|
|
|
return user |
|
|
|
|
|
|
|
@property |
|
|
|
def servers(self): |
|
|
|
return self._servers.values() |
|
|
@ -153,7 +165,7 @@ class ConnectionState: |
|
|
|
return utils.find(lambda m: m.id == msg_id, self.messages) |
|
|
|
|
|
|
|
def _add_server_from_data(self, guild): |
|
|
|
server = Server(**guild) |
|
|
|
server = Server(data=guild, state=self.ctx) |
|
|
|
Server.me = property(lambda s: s.get_member(self.user.id)) |
|
|
|
Server.voice_client = property(lambda s: self._get_voice_client(s.id)) |
|
|
|
self._add_server(server) |
|
|
@ -207,7 +219,7 @@ class ConnectionState: |
|
|
|
|
|
|
|
def parse_ready(self, data): |
|
|
|
self._ready_state = ReadyState(launch=asyncio.Event(), servers=[]) |
|
|
|
self.user = User(**data['user']) |
|
|
|
self.user = self.try_insert_user(data['user']) |
|
|
|
guilds = data.get('guilds') |
|
|
|
|
|
|
|
servers = self._ready_state.servers |
|
|
@ -217,7 +229,7 @@ class ConnectionState: |
|
|
|
servers.append(server) |
|
|
|
|
|
|
|
for pm in data.get('private_channels'): |
|
|
|
self._add_private_channel(PrivateChannel(self.user, **pm)) |
|
|
|
self._add_private_channel(PrivateChannel(me=self.user, data=pm, state=self.ctx)) |
|
|
|
|
|
|
|
compat.create_task(self._delay_ready(), loop=self.loop) |
|
|
|
|
|
|
@ -226,7 +238,7 @@ class ConnectionState: |
|
|
|
|
|
|
|
def parse_message_create(self, data): |
|
|
|
channel = self.get_channel(data.get('channel_id')) |
|
|
|
message = self._create_message(channel=channel, **data) |
|
|
|
message = Message(channel=channel, data=data, state=self.ctx) |
|
|
|
self.dispatch('message', message) |
|
|
|
self.messages.append(message) |
|
|
|
|
|
|
@ -255,7 +267,7 @@ class ConnectionState: |
|
|
|
# embed only edit |
|
|
|
message.embeds = data['embeds'] |
|
|
|
else: |
|
|
|
message._update(channel=message.channel, **data) |
|
|
|
message._update(channel=message.channel, data=data) |
|
|
|
|
|
|
|
self.dispatch('message_edit', older_message, message) |
|
|
|
|
|
|
@ -329,22 +341,11 @@ class ConnectionState: |
|
|
|
server._add_member(member) |
|
|
|
|
|
|
|
old_member = member._copy() |
|
|
|
member.status = data.get('status') |
|
|
|
try: |
|
|
|
member.status = Status(member.status) |
|
|
|
except: |
|
|
|
pass |
|
|
|
|
|
|
|
game = data.get('game', {}) |
|
|
|
member.game = Game(**game) if game else None |
|
|
|
member.name = user.get('username', member.name) |
|
|
|
member.avatar = user.get('avatar', member.avatar) |
|
|
|
member.discriminator = user.get('discriminator', member.discriminator) |
|
|
|
|
|
|
|
member._presence_update(data=data, user=user) |
|
|
|
self.dispatch('member_update', old_member, member) |
|
|
|
|
|
|
|
def parse_user_update(self, data): |
|
|
|
self.user = User(**data) |
|
|
|
self.user = User(state=self.ctx, data=data) |
|
|
|
|
|
|
|
def parse_channel_delete(self, data): |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
@ -361,7 +362,7 @@ class ConnectionState: |
|
|
|
if channel_type is ChannelType.group: |
|
|
|
channel = self._get_private_channel(channel_id) |
|
|
|
old_channel = copy.copy(channel) |
|
|
|
channel._update_group(**data) |
|
|
|
channel._update_group(data) |
|
|
|
self.dispatch('channel_update', old_channel, channel) |
|
|
|
return |
|
|
|
|
|
|
@ -370,32 +371,32 @@ class ConnectionState: |
|
|
|
channel = server.get_channel(channel_id) |
|
|
|
if channel is not None: |
|
|
|
old_channel = copy.copy(channel) |
|
|
|
channel._update(server=server, **data) |
|
|
|
channel._update(server, data) |
|
|
|
self.dispatch('channel_update', old_channel, channel) |
|
|
|
|
|
|
|
def parse_channel_create(self, data): |
|
|
|
ch_type = try_enum(ChannelType, data.get('type')) |
|
|
|
channel = None |
|
|
|
if ch_type in (ChannelType.group, ChannelType.private): |
|
|
|
channel = PrivateChannel(self.user, **data) |
|
|
|
channel = PrivateChannel(me=self.user, data=data, state=self.ctx) |
|
|
|
self._add_private_channel(channel) |
|
|
|
else: |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
|
if server is not None: |
|
|
|
channel = Channel(server=server, **data) |
|
|
|
channel = Channel(server=server, state=self.ctx, data=data) |
|
|
|
server._add_channel(channel) |
|
|
|
|
|
|
|
self.dispatch('channel_create', channel) |
|
|
|
|
|
|
|
def parse_channel_recipient_add(self, data): |
|
|
|
channel = self._get_private_channel(data.get('channel_id')) |
|
|
|
user = User(**data.get('user', {})) |
|
|
|
user = self.try_insert_user(data['user']) |
|
|
|
channel.recipients.append(user) |
|
|
|
self.dispatch('group_join', channel, user) |
|
|
|
|
|
|
|
def parse_channel_recipient_remove(self, data): |
|
|
|
channel = self._get_private_channel(data.get('channel_id')) |
|
|
|
user = User(**data.get('user', {})) |
|
|
|
user = self.try_insert_user(data['user']) |
|
|
|
try: |
|
|
|
channel.recipients.remove(user) |
|
|
|
except ValueError: |
|
|
@ -411,7 +412,7 @@ class ConnectionState: |
|
|
|
roles.append(role) |
|
|
|
|
|
|
|
data['roles'] = sorted(roles, key=lambda r: int(r.id)) |
|
|
|
return Member(server=server, **data) |
|
|
|
return Member(server=server, data=data, state=self.ctx) |
|
|
|
|
|
|
|
def parse_guild_member_add(self, data): |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
@ -441,35 +442,18 @@ class ConnectionState: |
|
|
|
|
|
|
|
def parse_guild_member_update(self, data): |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
|
user_id = data['user']['id'] |
|
|
|
user = data['user'] |
|
|
|
user_id = user['id'] |
|
|
|
member = server.get_member(user_id) |
|
|
|
if member is not None: |
|
|
|
user = data['user'] |
|
|
|
old_member = member._copy() |
|
|
|
member.name = user['username'] |
|
|
|
member.discriminator = user['discriminator'] |
|
|
|
member.avatar = user['avatar'] |
|
|
|
member.bot = user.get('bot', False) |
|
|
|
|
|
|
|
# the nickname change is optional, |
|
|
|
# if it isn't in the payload then it didn't change |
|
|
|
if 'nick' in data: |
|
|
|
member.nick = data['nick'] |
|
|
|
|
|
|
|
# update the roles |
|
|
|
member.roles = [server.default_role] |
|
|
|
for role in server.roles: |
|
|
|
if role.id in data['roles']: |
|
|
|
member.roles.append(role) |
|
|
|
|
|
|
|
# sort the roles by ID since they can be "randomised" |
|
|
|
member.roles.sort(key=lambda r: int(r.id)) |
|
|
|
member._update(data, user) |
|
|
|
self.dispatch('member_update', old_member, member) |
|
|
|
|
|
|
|
def parse_guild_emojis_update(self, data): |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
|
before_emojis = server.emojis |
|
|
|
server.emojis = [Emoji(server=server, **e) for e in data.get('emojis', [])] |
|
|
|
server.emojis = [Emoji(server=server, data=e, state=self.ctx) for e in data.get('emojis', [])] |
|
|
|
self.dispatch('server_emojis_update', before_emojis, server.emojis) |
|
|
|
|
|
|
|
def _get_create_server(self, data): |
|
|
@ -584,13 +568,13 @@ class ConnectionState: |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
|
if server is not None: |
|
|
|
if 'user' in data: |
|
|
|
user = User(**data['user']) |
|
|
|
user = self.try_insert_user(data['user']) |
|
|
|
self.dispatch('member_unban', server, user) |
|
|
|
|
|
|
|
def parse_guild_role_create(self, data): |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
|
role_data = data.get('role', {}) |
|
|
|
role = Role(server=server, **role_data) |
|
|
|
server = self._get_server(data['guild_id']) |
|
|
|
role_data = data['role'] |
|
|
|
role = Role(server=server, data=role_data, state=self.ctx) |
|
|
|
server._add_role(role) |
|
|
|
self.dispatch('server_role_create', role) |
|
|
|
|
|
|
@ -609,11 +593,12 @@ class ConnectionState: |
|
|
|
def parse_guild_role_update(self, data): |
|
|
|
server = self._get_server(data.get('guild_id')) |
|
|
|
if server is not None: |
|
|
|
role_id = data['role']['id'] |
|
|
|
role_data = data['role'] |
|
|
|
role_id = role_data['id'] |
|
|
|
role = utils.find(lambda r: r.id == role_id, server.roles) |
|
|
|
if role is not None: |
|
|
|
old_role = copy.copy(role) |
|
|
|
role._update(**data['role']) |
|
|
|
role._update(role_data) |
|
|
|
self.dispatch('server_role_update', old_role, role) |
|
|
|
|
|
|
|
def parse_guild_members_chunk(self, data): |
|
|
|