diff --git a/discord/client.py b/discord/client.py index b33f79abb..355a327df 100644 --- a/discord/client.py +++ b/discord/client.py @@ -37,7 +37,7 @@ import requests import json, re, time, copy from collections import deque import threading -from ws4py.client.threadedclient import WebSocketClient +from ws4py.client import WebSocketBaseClient import sys import logging @@ -71,6 +71,256 @@ class KeepAliveHandler(threading.Thread): log.debug(msg.format(payload['d'])) self.socket.send(json.dumps(payload)) +class WebSocket(WebSocketBaseClient): + def __init__(self, dispatch, url): + WebSocketBaseClient.__init__(self, url, + protocols=['http-only', 'chat']) + self.dispatch = dispatch + self.keep_alive = None + + def opened(self): + log.info('Opened at {}'.format(int(time.time()))) + self.dispatch('socket_opened') + + def closed(self, code, reason=None): + if self.keep_alive is not None: + self.keep_alive.stop.set() + log.info('Closed with {} ("{}") at {}'.format(code, reason, + int(time.time()))) + self.dispatch('socket_closed') + + def handshake_ok(self): + pass + + def received_message(self, msg): + response = json.loads(str(msg)) + log.debug('WebSocket Event: {}'.format(response)) + if response.get('op') != 0: + log.info("Unhandled op {}".format(response.get('op'))) + return # What about op 7? + + self.dispatch('socket_response', response) + event = response.get('t') + data = response.get('d') + + if event == 'READY': + interval = data['heartbeat_interval'] / 1000.0 + self.keep_alive = KeepAliveHandler(interval, self) + self.keep_alive.start() + + + if event in ('READY', 'MESSAGE_CREATE', 'MESSAGE_DELETE', + 'MESSAGE_UPDATE', 'PRESENCE_UPDATE', 'USER_UPDATE', + 'CHANNEL_DELETE', 'CHANNEL_UPDATE', 'CHANNEL_CREATE', + 'GUILD_MEMBER_ADD', 'GUILD_MEMBER_REMOVE', + 'GUILD_MEMBER_UPDATE', 'GUILD_CREATE', 'GUILD_DELETE'): + self.dispatch('socket_update', event, data) + + else: + log.info("Unhandled event {}".format(event)) + + +class ConnectionState(object): + def __init__(self, dispatch, **kwargs): + self.dispatch = dispatch + self.user = None + self.email = None + self.servers = [] + self.private_channels = [] + self.messages = deque([], maxlen=kwargs.get('max_length', 5000)) + + def _get_message(self, msg_id): + return utils.find(lambda m: m.id == msg_id, self.messages) + + def _get_server(self, guild_id): + return utils.find(lambda g: g.id == guild_id, self.servers) + + def _add_server(self, guild): + guild['roles'] = [Role(**role) for role in guild['roles']] + members = guild['members'] + owner = guild['owner_id'] + for i, member in enumerate(members): + roles = member['roles'] + for j, roleid in enumerate(roles): + role = utils.find(lambda r: r.id == roleid, guild['roles']) + if role is not None: + roles[j] = role + members[i] = Member(**member) + + # found the member that owns the server + if members[i].id == owner: + owner = members[i] + + for presence in guild['presences']: + user_id = presence['user']['id'] + member = utils.find(lambda m: m.id == user_id, members) + if member is not None: + member.status = presence['status'] + member.game_id = presence['game_id'] + + + server = Server(owner=owner, **guild) + + # give all the members their proper server + for member in server.members: + member.server = server + + channels = [Channel(server=server, **channel) + for channel in guild['channels']] + server.channels = channels + self.servers.append(server) + + def handle_ready(self, data): + self.user = User(**data['user']) + guilds = data.get('guilds') + + for guild in guilds: + self._add_server(guild) + + for pm in data.get('private_channels'): + self.private_channels.append(PrivateChannel(id=pm['id'], + user=User(**pm['recipient']))) + + # we're all ready + self.dispatch('ready') + + def handle_message_create(self, data): + channel = self.get_channel(data.get('channel_id')) + message = Message(channel=channel, **data) + self.dispatch('message', message) + self.messages.append(message) + + def handle_message_delete(self, data): + channel = self.get_channel(data.get('channel_id')) + message_id = data.get('id') + found = self._get_message(message_id) + if found is not None: + self.dispatch('message_delete', found) + self.messages.remove(found) + + def handle_message_update(self, data): + older_message = self._get_message(data.get('id')) + if older_message is not None: + # create a copy of the new message + message = copy.deepcopy(older_message) + # update the new update + for attr in data: + if attr == 'channel_id' or attr == 'author': + continue + value = data[attr] + if 'time' in attr: + setattr(message, attr, utils.parse_time(value)) + else: + setattr(message, attr, value) + self.dispatch('message_edit', older_message, message) + # update the older message + older_message = message + + def handle_presence_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + status = data.get('status') + user = data['user'] + member_id = user['id'] + member = utils.find(lambda m: m.id == member_id, server.members) + if member is not None: + member.status = data.get('status') + member.game_id = data.get('game_id') + member.name = user.get('username', member.name) + member.avatar = user.get('avatar', member.avatar) + + # call the event now + self.dispatch('status', member) + self.dispatch('member_update', member) + + def handle_user_update(self, data): + self.user = User(**data) + + def handle_channel_delete(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + channel_id = data.get('id') + channel = utils.find(lambda c: c.id == channel_id, server.channels) + server.channels.remove(channel) + self.dispatch('channel_delete', channel) + + def handle_channel_update(self, data): + server = self._get_server(data.get('guild_id')) + if server is not None: + channel_id = data.get('id') + channel = utils.find(lambda c: c.id == channel_id, server.channels) + channel.update(server=server, **data) + self.dispatch('channel_update', channel) + + def handle_channel_create(self, data): + is_private = data.get('is_private', False) + channel = None + if is_private: + recipient = User(**data.get('recipient')) + pm_id = data.get('id') + channel = PrivateChannel(id=pm_id, user=recipient) + self.private_channels.append(channel) + else: + server = self._get_server(data.get('guild_id')) + if server is not None: + channel = Channel(server=server, **data) + server.channels.append(channel) + + self.dispatch('channel_create', channel) + + def handle_guild_member_add(self, data): + server = self._get_server(data.get('guild_id')) + member = Member(server=server, deaf=False, mute=False, **data) + server.members.append(member) + self.dispatch('member_join', member) + + def handle_guild_member_remove(self, data): + server = self._get_server(data.get('guild_id')) + user_id = data['user']['id'] + member = utils.find(lambda m: m.id == user_id, server.members) + server.members.remove(member) + self.dispatch('member_remove', member) + + def handle_guild_member_update(self, data): + server = self._get_server(data.get('guild_id')) + user_id = data['user']['id'] + member = utils.find(lambda m: m.id == user_id, server.members) + if member is not None: + user = data['user'] + member.name = user['username'] + member.discriminator = user['discriminator'] + member.avatar = user['avatar'] + member.roles = [] + # update the roles + for role in server.roles: + if role.id in data['roles']: + member.roles.append(role) + + self.dispatch('member_update', member) + + def handle_guild_create(self, data): + self._add_server(data) + self.dispatch('server_create', self.servers[-1]) + + def handle_guild_delete(self, data): + server = self._get_server(data.get('id')) + self.servers.remove(server) + self.dispatch('server_delete', server) + + def get_channel(self, id): + if id is None: + return None + + for server in self.servers: + for channel in server.channels: + if channel.id == id: + return channel + + for pm in self.private_channels: + if pm.id == id: + return pm + + class Client(object): """Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -102,12 +352,8 @@ class Client(object): def __init__(self, **kwargs): self._is_logged_in = False - self.user = None - self.email = None - self.servers = [] - self.private_channels = [] + self.connection = ConnectionState(self.dispatch, **kwargs) self.token = '' - self.messages = deque([], maxlen=kwargs.get('max_length', 5000)) self.events = { 'on_ready': _null_event, 'on_disconnect': _null_event, @@ -133,60 +379,11 @@ class Client(object): 'authorization': self.token, } - def _get_message(self, msg_id): - return utils.find(lambda m: m.id == msg_id, self.messages) - - def _get_server(self, guild_id): - return utils.find(lambda g: g.id == guild_id, self.servers) - - def _add_server(self, guild): - guild['roles'] = [Role(**role) for role in guild['roles']] - members = guild['members'] - owner = guild['owner_id'] - for i, member in enumerate(members): - roles = member['roles'] - for j, roleid in enumerate(roles): - role = utils.find(lambda r: r.id == roleid, guild['roles']) - if role is not None: - roles[j] = role - members[i] = Member(**member) - - # found the member that owns the server - if members[i].id == owner: - owner = members[i] - - for presence in guild['presences']: - user_id = presence['user']['id'] - member = utils.find(lambda m: m.id == user_id, members) - if member is not None: - member.status = presence['status'] - member.game_id = presence['game_id'] - - - server = Server(owner=owner, **guild) - - # give all the members their proper server - for member in server.members: - member.server = server - - channels = [Channel(server=server, **channel) for channel in guild['channels']] - server.channels = channels - self.servers.append(server) - def _create_websocket(self, url, reconnect=False): if url is None: raise GatewayNotFound() log.info('websocket gateway found') - self.ws = WebSocketClient(url, protocols=['http-only', 'chat']) - - # this is kind of hacky, but it's to avoid deadlocks. - # i.e. python does not allow me to have the current thread running if it's self - # it throws a 'cannot join current thread' RuntimeError - # So instead of doing a basic inheritance scheme, we're overriding the member functions. - - self.ws.opened = self._opened - self.ws.closed = self._closed - self.ws.received_message = self._received_message + self.ws = WebSocket(self.dispatch, url) self.ws.connect() log.info('websocket has connected') @@ -220,6 +417,23 @@ class Client(object): msg = 'Caught exception in {} with args (*{}, **{})' log.exception(msg.format(event_method, args, kwargs)) + # Compatibility shim + def __getattr__(self, name): + if name in ('user', 'email', 'servers', 'private_channels', 'messages', + 'get_channel'): + return getattr(self.connection, name) + else: + msg = "'{}' object has no attribute '{}'" + raise AttributeError(msg.format(self.__class__, name)) + + # Compatibility shim + def __setattr__(self, name, value): + if name in ('user', 'email', 'servers', 'private_channels', + 'messages'): + return setattr(self.connection, name, value) + else: + object.__setattr__(self, name, value) + def dispatch(self, event, *args, **kwargs): log.debug("Dispatching event {}".format(event)) handle_method = '_'.join(('handle', event)) @@ -242,156 +456,14 @@ class Client(object): log.error('an error ({}) occurred in event {} so on_error is invoked instead'.format(type(e).__name__, event_name)) self.events['on_error'](event_name, *sys.exc_info()) - def _received_message(self, msg): - response = json.loads(str(msg)) - log.debug('WebSocket Event: {}'.format(response)) - if response.get('op') != 0: - return - - self.dispatch('response', response) - event = response.get('t') - data = response.get('d') - - if event == 'READY': - self.user = User(**data['user']) - guilds = data.get('guilds') - - for guild in guilds: - self._add_server(guild) - - for pm in data.get('private_channels'): - self.private_channels.append(PrivateChannel(id=pm['id'], user=User(**pm['recipient']))) - - # set the keep alive interval.. - interval = data.get('heartbeat_interval') / 1000.0 - self.keep_alive = KeepAliveHandler(interval, self.ws) - self.keep_alive.start() - - # we're all ready - self.dispatch('ready') - elif event == 'MESSAGE_CREATE': - channel = self.get_channel(data.get('channel_id')) - message = Message(channel=channel, **data) - self.dispatch('message', message) - self.messages.append(message) - elif event == 'MESSAGE_DELETE': - channel = self.get_channel(data.get('channel_id')) - message_id = data.get('id') - found = self._get_message(message_id) - if found is not None: - self.dispatch('message_delete', found) - self.messages.remove(found) - elif event == 'MESSAGE_UPDATE': - older_message = self._get_message(data.get('id')) - if older_message is not None: - # create a copy of the new message - message = copy.deepcopy(older_message) - # update the new update - for attr in data: - if attr == 'channel_id' or attr == 'author': - continue - value = data[attr] - if 'time' in attr: - setattr(message, attr, utils.parse_time(value)) - else: - setattr(message, attr, value) - self.dispatch('message_edit', older_message, message) - # update the older message - older_message = message - - elif event == 'PRESENCE_UPDATE': - server = self._get_server(data.get('guild_id')) - if server is not None: - status = data.get('status') - user = data['user'] - member_id = user['id'] - member = utils.find(lambda m: m.id == member_id, server.members) - if member is not None: - member.status = data.get('status') - member.game_id = data.get('game_id') - member.name = user.get('username', member.name) - member.avatar = user.get('avatar', member.avatar) - - # call the event now - self.dispatch('status', member) - self.dispatch('member_update', member) - elif event == 'USER_UPDATE': - self.user = User(**data) - elif event == 'CHANNEL_DELETE': - server = self._get_server(data.get('guild_id')) - if server is not None: - channel_id = data.get('id') - channel = utils.find(lambda c: c.id == channel_id, server.channels) - server.channels.remove(channel) - self.dispatch('channel_delete', channel) - elif event == 'CHANNEL_UPDATE': - server = self._get_server(data.get('guild_id')) - if server is not None: - channel_id = data.get('id') - channel = utils.find(lambda c: c.id == channel_id, server.channels) - channel.update(server=server, **data) - self.dispatch('channel_update', channel) - elif event == 'CHANNEL_CREATE': - is_private = data.get('is_private', False) - channel = None - if is_private: - recipient = User(**data.get('recipient')) - pm_id = data.get('id') - channel = PrivateChannel(id=pm_id, user=recipient) - self.private_channels.append(channel) - else: - server = self._get_server(data.get('guild_id')) - if server is not None: - channel = Channel(server=server, **data) - server.channels.append(channel) - - self.dispatch('channel_create', channel) - elif event == 'GUILD_MEMBER_ADD': - server = self._get_server(data.get('guild_id')) - member = Member(server=server, deaf=False, mute=False, **data) - server.members.append(member) - self.dispatch('member_join', member) - elif event == 'GUILD_MEMBER_REMOVE': - server = self._get_server(data.get('guild_id')) - user_id = data['user']['id'] - member = utils.find(lambda m: m.id == user_id, server.members) - server.members.remove(member) - self.dispatch('member_remove', member) - elif event == 'GUILD_MEMBER_UPDATE': - server = self._get_server(data.get('guild_id')) - user_id = data['user']['id'] - member = utils.find(lambda m: m.id == user_id, server.members) - if member is not None: - user = data['user'] - member.name = user['username'] - member.discriminator = user['discriminator'] - member.avatar = user['avatar'] - member.roles = [] - # update the roles - for role in server.roles: - if role.id in data['roles']: - member.roles.append(role) - - self.dispatch('member_update', member) - elif event == 'GUILD_CREATE': - self._add_server(data) - self.dispatch('server_create', self.servers[-1]) - elif event == 'GUILD_DELETE': - server = self._get_server(data.get('id')) - self.servers.remove(server) - self.dispatch('server_delete', server) - - def _opened(self): - log.info('Opened at {}'.format(int(time.time()))) - - def _closed(self, code, reason=None): - log.info('Closed with {} ("{}") at {}'.format(code, reason, int(time.time()))) - self.dispatch('disconnect') + def handle_socket_update(self, event, data): + method = '_'.join(('handle', event.lower())) + getattr(self.connection, method)(data) def run(self): """Runs the client and allows it to receive messages and events.""" log.info('Client is being run') - self.ws.run_forever() + self.ws.run() @property def is_logged_in(self): @@ -399,18 +471,10 @@ class Client(object): return self._is_logged_in def get_channel(self, id): - """Returns a :class:`Channel` or :class:`PrivateChannel` with the following ID. If not found, returns None.""" - if id is None: - return None - - for server in self.servers: - for channel in server.channels: - if channel.id == id: - return channel - - for pm in self.private_channels: - if pm.id == id: - return pm + """Returns a :class:`Channel` or :class:`PrivateChannel` with the + following ID. If not found, returns None. + """ + return self.connection.get_channel(id) def start_private_message(self, user): """Starts a private message with the user. This allows you to :meth:`send_message` to it. @@ -578,7 +642,6 @@ class Client(object): response = requests.post(endpoints.LOGOUT) self.ws.close() self._is_logged_in = False - self.keep_alive.stop.set() log.debug(request_logging_format.format(name='logout', response=response)) def logs_from(self, channel, limit=500):