From 5e671a0d0ddcba46581aea5bd9d8bb887498036d Mon Sep 17 00:00:00 2001 From: Hornwitser Date: Tue, 29 Sep 2015 08:47:37 +0200 Subject: [PATCH] Move socket and connection state out of Client Move the socket message handling and Discord connection state tracking out of the Client class. The WebSocket class handles the ws4py based WebSocket to Discord, maintains the keepalive and dispatches socket_ based on activity. The ConnectionSTate class maintains the state associated with the WebSocket connection with Discord. In a reconnect and switch gateway scenario this state can be kept for a faster and less disruptive recovery. --- discord/client.py | 493 ++++++++++++++++++++++++++-------------------- 1 file changed, 278 insertions(+), 215 deletions(-) diff --git a/discord/client.py b/discord/client.py index b33f79abb..355a327df 100644 --- a/discord/client.py +++ b/discord/client.py @@ -37,7 +37,7 @@ import requests import json, re, time, copy from collections import deque import threading -from ws4py.client.threadedclient import WebSocketClient +from ws4py.client import WebSocketBaseClient import sys import logging @@ -71,6 +71,256 @@ class KeepAliveHandler(threading.Thread): log.debug(msg.format(payload['d'])) self.socket.send(json.dumps(payload)) +class WebSocket(WebSocketBaseClient): + def __init__(self, dispatch, url): + WebSocketBaseClient.__init__(self, url, + protocols=['http-only', 'chat']) + self.dispatch = dispatch + self.keep_alive = None + + def opened(self): + log.info('Opened at {}'.format(int(time.time()))) + self.dispatch('socket_opened') + + def closed(self, code, reason=None): + if self.keep_alive is not None: + self.keep_alive.stop.set() + log.info('Closed with {} ("{}") at {}'.format(code, reason, + int(time.time()))) + self.dispatch('socket_closed') + + def handshake_ok(self): + pass + + def received_message(self, msg): + response = json.loads(str(msg)) + log.debug('WebSocket Event: {}'.format(response)) + if response.get('op') != 0: + log.info("Unhandled op {}".format(response.get('op'))) + return # What about op 7? + + self.dispatch('socket_response', response) + event = response.get('t') + data = response.get('d') + + if event == 'READY': + interval = data['heartbeat_interval'] / 1000.0 + self.keep_alive = KeepAliveHandler(interval, self) + self.keep_alive.start() + + + if event in ('READY', 'MESSAGE_CREATE', 'MESSAGE_DELETE', + 'MESSAGE_UPDATE', 'PRESENCE_UPDATE', 'USER_UPDATE', + 'CHANNEL_DELETE', 'CHANNEL_UPDATE', 'CHANNEL_CREATE', + 'GUILD_MEMBER_ADD', 'GUILD_MEMBER_REMOVE', + 'GUILD_MEMBER_UPDATE', 'GUILD_CREATE', 'GUILD_DELETE'): + self.dispatch('socket_update', event, data) + + else: + log.info("Unhandled event {}".format(event)) + + +class ConnectionState(object): + def __init__(self, dispatch, **kwargs): + self.dispatch = dispatch + self.user = None + self.email = None + self.servers = [] + self.private_channels = [] + self.messages = deque([], maxlen=kwargs.get('max_length', 5000)) + + def _get_message(self, msg_id): + return utils.find(lambda m: m.id == msg_id, self.messages) + + def _get_server(self, guild_id): + return utils.find(lambda g: g.id == guild_id, self.servers) + + def _add_server(self, guild): + guild['roles'] = [Role(**role) for role in guild['roles']] + members = guild['members'] + owner = guild['owner_id'] + for i, member in enumerate(members): + roles = member['roles'] + for j, roleid in enumerate(roles): + role = utils.find(lambda r: r.id == roleid, guild['roles']) + if role is not None: + roles[j] = role + members[i] = Member(**member) + + # found the member that owns the server + if members[i].id == owner: + owner = members[i] + + for presence in guild['presences']: + user_id = presence['user']['id'] + member = utils.find(lambda m: m.id == user_id, members) + if member is not None: + member.status = presence['status'] + member.game_id = presence['game_id'] + + + server = Server(owner=owner, **guild) + + # give all the members their proper server + for member in server.members: + member.server = server + + channels = [Channel(server=server, **channel) + for channel in guild['channels']] + server.channels = channels + self.servers.append(server) + + def handle_ready(self, data): + self.user = User(**data['user']) + guilds = data.get('guilds') + + for guild in guilds: + self._add_server(guild) + + for pm in data.get('private_channels'): + self.private_channels.append(PrivateChannel(id=pm['id'], + user=User(**pm['recipient']))) + + # we're all ready + self.dispatch('ready') + + def handle_message_create(self, data): + channel = self.get_channel(data.get('channel_id')) + message = Message(channel=channel, **data) + self.dispatch('message', message) + self.messages.append(message) + + def handle_message_delete(self, data): + channel = self.get_channel(data.get('channel_id')) + message_id = data.get('id') + found = self._get_message(message_id) + if found is not None: + self.dispatch('message_delete', found) + self.messages.remove(found) + + def handle_message_update(self, data): + older_message = self._get_message(data.get('id')) + if older_message is not None: + # create a copy of the new message + message = copy.deepcopy(older_message) + # update the new update + for attr in data: + if attr == 'channel_id' or attr == 'author': + continue + value = data[attr] + if 'time' in attr: + setattr(message, attr, utils.parse_time(value)) + else: + setattr(message, attr, value) + self.dispatch('message_edit', older_message, message) + # update the older message + older_message = message + + def handle_presence_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + status = data.get('status') + user = data['user'] + member_id = user['id'] + member = utils.find(lambda m: m.id == member_id, server.members) + if member is not None: + member.status = data.get('status') + member.game_id = data.get('game_id') + member.name = user.get('username', member.name) + member.avatar = user.get('avatar', member.avatar) + + # call the event now + self.dispatch('status', member) + self.dispatch('member_update', member) + + def handle_user_update(self, data): + self.user = User(**data) + + def handle_channel_delete(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + channel_id = data.get('id') + channel = utils.find(lambda c: c.id == channel_id, server.channels) + server.channels.remove(channel) + self.dispatch('channel_delete', channel) + + def handle_channel_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + channel_id = data.get('id') + channel = utils.find(lambda c: c.id == channel_id, server.channels) + channel.update(server=server, **data) + self.dispatch('channel_update', channel) + + def handle_channel_create(self, data): + is_private = data.get('is_private', False) + channel = None + if is_private: + recipient = User(**data.get('recipient')) + pm_id = data.get('id') + channel = PrivateChannel(id=pm_id, user=recipient) + self.private_channels.append(channel) + else: + server = self._get_server(data.get('guild_id')) + if server is not None: + channel = Channel(server=server, **data) + server.channels.append(channel) + + self.dispatch('channel_create', channel) + + def handle_guild_member_add(self, data): + server = self._get_server(data.get('guild_id')) + member = Member(server=server, deaf=False, mute=False, **data) + server.members.append(member) + self.dispatch('member_join', member) + + def handle_guild_member_remove(self, data): + server = self._get_server(data.get('guild_id')) + user_id = data['user']['id'] + member = utils.find(lambda m: m.id == user_id, server.members) + server.members.remove(member) + self.dispatch('member_remove', member) + + def handle_guild_member_update(self, data): + server = self._get_server(data.get('guild_id')) + user_id = data['user']['id'] + member = utils.find(lambda m: m.id == user_id, server.members) + if member is not None: + user = data['user'] + member.name = user['username'] + member.discriminator = user['discriminator'] + member.avatar = user['avatar'] + member.roles = [] + # update the roles + for role in server.roles: + if role.id in data['roles']: + member.roles.append(role) + + self.dispatch('member_update', member) + + def handle_guild_create(self, data): + self._add_server(data) + self.dispatch('server_create', self.servers[-1]) + + def handle_guild_delete(self, data): + server = self._get_server(data.get('id')) + self.servers.remove(server) + self.dispatch('server_delete', server) + + def get_channel(self, id): + if id is None: + return None + + for server in self.servers: + for channel in server.channels: + if channel.id == id: + return channel + + for pm in self.private_channels: + if pm.id == id: + return pm + + class Client(object): """Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -102,12 +352,8 @@ class Client(object): def __init__(self, **kwargs): self._is_logged_in = False - self.user = None - self.email = None - self.servers = [] - self.private_channels = [] + self.connection = ConnectionState(self.dispatch, **kwargs) self.token = '' - self.messages = deque([], maxlen=kwargs.get('max_length', 5000)) self.events = { 'on_ready': _null_event, 'on_disconnect': _null_event, @@ -133,60 +379,11 @@ class Client(object): 'authorization': self.token, } - def _get_message(self, msg_id): - return utils.find(lambda m: m.id == msg_id, self.messages) - - def _get_server(self, guild_id): - return utils.find(lambda g: g.id == guild_id, self.servers) - - def _add_server(self, guild): - guild['roles'] = [Role(**role) for role in guild['roles']] - members = guild['members'] - owner = guild['owner_id'] - for i, member in enumerate(members): - roles = member['roles'] - for j, roleid in enumerate(roles): - role = utils.find(lambda r: r.id == roleid, guild['roles']) - if role is not None: - roles[j] = role - members[i] = Member(**member) - - # found the member that owns the server - if members[i].id == owner: - owner = members[i] - - for presence in guild['presences']: - user_id = presence['user']['id'] - member = utils.find(lambda m: m.id == user_id, members) - if member is not None: - member.status = presence['status'] - member.game_id = presence['game_id'] - - - server = Server(owner=owner, **guild) - - # give all the members their proper server - for member in server.members: - member.server = server - - channels = [Channel(server=server, **channel) for channel in guild['channels']] - server.channels = channels - self.servers.append(server) - def _create_websocket(self, url, reconnect=False): if url is None: raise GatewayNotFound() log.info('websocket gateway found') - self.ws = WebSocketClient(url, protocols=['http-only', 'chat']) - - # this is kind of hacky, but it's to avoid deadlocks. - # i.e. python does not allow me to have the current thread running if it's self - # it throws a 'cannot join current thread' RuntimeError - # So instead of doing a basic inheritance scheme, we're overriding the member functions. - - self.ws.opened = self._opened - self.ws.closed = self._closed - self.ws.received_message = self._received_message + self.ws = WebSocket(self.dispatch, url) self.ws.connect() log.info('websocket has connected') @@ -220,6 +417,23 @@ class Client(object): msg = 'Caught exception in {} with args (*{}, **{})' log.exception(msg.format(event_method, args, kwargs)) + # Compatibility shim + def __getattr__(self, name): + if name in ('user', 'email', 'servers', 'private_channels', 'messages', + 'get_channel'): + return getattr(self.connection, name) + else: + msg = "'{}' object has no attribute '{}'" + raise AttributeError(msg.format(self.__class__, name)) + + # Compatibility shim + def __setattr__(self, name, value): + if name in ('user', 'email', 'servers', 'private_channels', + 'messages'): + return setattr(self.connection, name, value) + else: + object.__setattr__(self, name, value) + def dispatch(self, event, *args, **kwargs): log.debug("Dispatching event {}".format(event)) handle_method = '_'.join(('handle', event)) @@ -242,156 +456,14 @@ class Client(object): log.error('an error ({}) occurred in event {} so on_error is invoked instead'.format(type(e).__name__, event_name)) self.events['on_error'](event_name, *sys.exc_info()) - def _received_message(self, msg): - response = json.loads(str(msg)) - log.debug('WebSocket Event: {}'.format(response)) - if response.get('op') != 0: - return - - self.dispatch('response', response) - event = response.get('t') - data = response.get('d') - - if event == 'READY': - self.user = User(**data['user']) - guilds = data.get('guilds') - - for guild in guilds: - self._add_server(guild) - - for pm in data.get('private_channels'): - self.private_channels.append(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) - - # set the keep alive interval.. - interval = data.get('heartbeat_interval') / 1000.0 - self.keep_alive = KeepAliveHandler(interval, self.ws) - self.keep_alive.start() - - # we're all ready - self.dispatch('ready') - elif event == 'MESSAGE_CREATE': - channel = self.get_channel(data.get('channel_id')) - message = Message(channel=channel, **data) - self.dispatch('message', message) - self.messages.append(message) - elif event == 'MESSAGE_DELETE': - channel = self.get_channel(data.get('channel_id')) - message_id = data.get('id') - found = self._get_message(message_id) - if found is not None: - self.dispatch('message_delete', found) - self.messages.remove(found) - elif event == 'MESSAGE_UPDATE': - older_message = self._get_message(data.get('id')) - if older_message is not None: - # create a copy of the new message - message = copy.deepcopy(older_message) - # update the new update - for attr in data: - if attr == 'channel_id' or attr == 'author': - continue - value = data[attr] - if 'time' in attr: - setattr(message, attr, utils.parse_time(value)) - else: - setattr(message, attr, value) - self.dispatch('message_edit', older_message, message) - # update the older message - older_message = message - - elif event == 'PRESENCE_UPDATE': - server = self._get_server(data.get('guild_id')) - if server is not None: - status = data.get('status') - user = data['user'] - member_id = user['id'] - member = utils.find(lambda m: m.id == member_id, server.members) - if member is not None: - member.status = data.get('status') - member.game_id = data.get('game_id') - member.name = user.get('username', member.name) - member.avatar = user.get('avatar', member.avatar) - - # call the event now - self.dispatch('status', member) - self.dispatch('member_update', member) - elif event == 'USER_UPDATE': - self.user = User(**data) - elif event == 'CHANNEL_DELETE': - server = self._get_server(data.get('guild_id')) - if server is not None: - channel_id = data.get('id') - channel = utils.find(lambda c: c.id == channel_id, server.channels) - server.channels.remove(channel) - self.dispatch('channel_delete', channel) - elif event == 'CHANNEL_UPDATE': - server = self._get_server(data.get('guild_id')) - if server is not None: - channel_id = data.get('id') - channel = utils.find(lambda c: c.id == channel_id, server.channels) - channel.update(server=server, **data) - self.dispatch('channel_update', channel) - elif event == 'CHANNEL_CREATE': - is_private = data.get('is_private', False) - channel = None - if is_private: - recipient = User(**data.get('recipient')) - pm_id = data.get('id') - channel = PrivateChannel(id=pm_id, user=recipient) - self.private_channels.append(channel) - else: - server = self._get_server(data.get('guild_id')) - if server is not None: - channel = Channel(server=server, **data) - server.channels.append(channel) - - self.dispatch('channel_create', channel) - elif event == 'GUILD_MEMBER_ADD': - server = self._get_server(data.get('guild_id')) - member = Member(server=server, deaf=False, mute=False, **data) - server.members.append(member) - self.dispatch('member_join', member) - elif event == 'GUILD_MEMBER_REMOVE': - server = self._get_server(data.get('guild_id')) - user_id = data['user']['id'] - member = utils.find(lambda m: m.id == user_id, server.members) - server.members.remove(member) - self.dispatch('member_remove', member) - elif event == 'GUILD_MEMBER_UPDATE': - server = self._get_server(data.get('guild_id')) - user_id = data['user']['id'] - member = utils.find(lambda m: m.id == user_id, server.members) - if member is not None: - user = data['user'] - member.name = user['username'] - member.discriminator = user['discriminator'] - member.avatar = user['avatar'] - member.roles = [] - # update the roles - for role in server.roles: - if role.id in data['roles']: - member.roles.append(role) - - self.dispatch('member_update', member) - elif event == 'GUILD_CREATE': - self._add_server(data) - self.dispatch('server_create', self.servers[-1]) - elif event == 'GUILD_DELETE': - server = self._get_server(data.get('id')) - self.servers.remove(server) - self.dispatch('server_delete', server) - - def _opened(self): - log.info('Opened at {}'.format(int(time.time()))) - - def _closed(self, code, reason=None): - log.info('Closed with {} ("{}") at {}'.format(code, reason, int(time.time()))) - self.dispatch('disconnect') + def handle_socket_update(self, event, data): + method = '_'.join(('handle', event.lower())) + getattr(self.connection, method)(data) def run(self): """Runs the client and allows it to receive messages and events.""" log.info('Client is being run') - self.ws.run_forever() + self.ws.run() @property def is_logged_in(self): @@ -399,18 +471,10 @@ class Client(object): return self._is_logged_in def get_channel(self, id): - """Returns a :class:`Channel` or :class:`PrivateChannel` with the following ID. If not found, returns None.""" - if id is None: - return None - - for server in self.servers: - for channel in server.channels: - if channel.id == id: - return channel - - for pm in self.private_channels: - if pm.id == id: - return pm + """Returns a :class:`Channel` or :class:`PrivateChannel` with the + following ID. If not found, returns None. + """ + return self.connection.get_channel(id) def start_private_message(self, user): """Starts a private message with the user. This allows you to :meth:`send_message` to it. @@ -578,7 +642,6 @@ class Client(object): response = requests.post(endpoints.LOGOUT) self.ws.close() self._is_logged_in = False - self.keep_alive.stop.set() log.debug(request_logging_format.format(name='logout', response=response)) def logs_from(self, channel, limit=500):