From 59c3f1ba3fdd909ccf5e1dba6188b63f3c1843bb Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 28 Sep 2016 19:14:05 -0500 Subject: [PATCH] Websocket subprocess, more state stuff --- disco/gateway/client.py | 28 +++++++-------- disco/gateway/events.py | 12 +++---- disco/state.py | 53 ++++++++++++++++++++------- disco/types/base.py | 1 + disco/types/guild.py | 6 ++-- disco/util/__init__.py | 4 +-- disco/util/websocket.py | 80 ++++++++++++++++++++++++++++++++++++++++- disco/voice/client.py | 2 +- 8 files changed, 146 insertions(+), 40 deletions(-) diff --git a/disco/gateway/client.py b/disco/gateway/client.py index cca0a1d..55470db 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -1,11 +1,10 @@ import gevent import zlib -import ssl from disco.gateway.packets import OPCode from disco.gateway.events import GatewayEvent from disco.util.json import loads, dumps -from disco.util.websocket import Websocket +from disco.util.websocket import WebsocketProcessProxy from disco.util.logging import LoggingClass GATEWAY_VERSION = 6 @@ -94,16 +93,15 @@ class GatewayClient(LoggingClass): self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) - self.ws = Websocket( - self._cached_gateway_url, - on_message=self.on_message, - on_error=self.on_error, - on_open=self.on_open, - on_close=self.on_close, - ) - self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) - - def on_message(self, ws, msg): + self.ws = WebsocketProcessProxy(self._cached_gateway_url) + self.ws.emitter.on('on_open', self.on_open) + self.ws.emitter.on('on_error', self.on_error) + self.ws.emitter.on('on_close', self.on_close) + self.ws.emitter.on('on_message', self.on_message) + + self.ws.run_forever() + + def on_message(self, msg): # Detect zlib and decompress if msg[0] != '{': msg = zlib.decompress(msg, 15, TEN_MEGABYTES).decode("utf-8") @@ -121,12 +119,12 @@ class GatewayClient(LoggingClass): # Emit packet self.packets.emit(OPCode[data['op']], data) - def on_error(self, ws, error): + def on_error(self, error): if isinstance(error, KeyboardInterrupt): self.shutting_down = True raise Exception('WS recieved error: %s', error) - def on_open(self, ws): + def on_open(self): if self.seq and self.session_id: self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.send(OPCode.RESUME, { @@ -149,7 +147,7 @@ class GatewayClient(LoggingClass): } }) - def on_close(self, ws, code, reason): + def on_close(self, code, reason): if self.shutting_down: self.log.info('WS Closed: shutting down') return diff --git a/disco/gateway/events.py b/disco/gateway/events.py index 51b3835..2aac7bc 100644 --- a/disco/gateway/events.py +++ b/disco/gateway/events.py @@ -7,15 +7,15 @@ from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceS class GatewayEvent(skema.Model): @staticmethod - def from_dispatch(client, obj): - cls = globals().get(inflection.camelize(obj['t'].lower())) + def from_dispatch(client, data): + cls = globals().get(inflection.camelize(data['t'].lower())) if not cls: - raise Exception('Could not find cls for {}'.format(obj['t'])) + raise Exception('Could not find cls for {}'.format(data['t'])) - obj = cls.create(obj['d']) + obj = cls.create(data['d']) - for item in skema_find_recursive_by_type(obj, skema.ModelType): - item.client = client + for field, value in skema_find_recursive_by_type(obj, skema.ModelType): + value.client = client return obj diff --git a/disco/state.py b/disco/state.py index 0833892..60d23b2 100644 --- a/disco/state.py +++ b/disco/state.py @@ -1,6 +1,7 @@ from collections import defaultdict, deque, namedtuple from weakref import WeakValueDictionary +from disco.gateway.packets import OPCode StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id']) @@ -42,6 +43,7 @@ class State(object): self.client.events.on('GuildMemberAdd', self.on_guild_member_add) self.client.events.on('GuildMemberRemove', self.on_guild_member_remove) self.client.events.on('GuildMemberUpdate', self.on_guild_member_update) + self.client.events.on('GuildMemberChunk', self.on_guild_member_chunk) # Guild roles self.client.events.on('GuildRoleCreate', self.on_guild_role_create) @@ -93,6 +95,13 @@ class State(object): for member in event.guild.members.values(): self.users[member.user.id] = member.user + # Request full member list + self.client.gw.send(OPCode.REQUEST_GUILD_MEMBERS, { + 'guild_id': event.guild.id, + 'query': '', + 'limit': 0, + }) + def on_guild_update(self, event): self.guilds[event.guild.id].update(event.guild) @@ -137,34 +146,52 @@ class State(object): else: event.member.user = self.users[event.member.user.id] - if event.member.guild.id not in self.guilds: + if event.member.guild_id not in self.guilds: return - self.guilds[event.member.guild.id].members[event.member.id] = event.member + event.member.guild = self.guilds[event.member.guild_id] + self.guilds[event.member.guild_id].members[event.member.id] = event.member def on_guild_member_update(self, event): - if event.member.guild.id not in self.guilds: + if event.guild_id not in self.guilds: return - # Ensure the reference is correct - assert(event.member.user.id in self.users) - event.member.user = self.users[event.member.user.id] - self.guilds[event.member.guild.id].members[event.member.id] = event.member + self.guilds[event.guild_id].members[event.user.id].roles = event.roles + self.guilds[event.guild_id].members[event.user.id].user.update(event.user) def on_guild_member_remove(self, event): - if event.member.guild.id not in self.guilds: + if event.guild_id not in self.guilds: return - if event.member.id not in self.guilds[event.member.guild.id].members: + if event.user.id not in self.guilds[event.guild_id].members: return - del self.guilds[event.member.guild.id].members[event.member.id] + del self.guilds[event.guild_id].members[event.user.id] + + def on_guild_member_chunk(self, event): + if event.guild_id not in self.guilds: + return + + guild = self.guilds[event.guild_id] + for member in event.members: + member.guild = guild + member.guild_id = guild.id + guild.members[member.id] = member def on_guild_role_create(self, event): - pass + if event.guild_id not in self.guilds: + return + + self.guilds[event.guild_id].roles[event.role.id] = event.role def on_guild_role_update(self, event): - pass + if event.guild_id not in self.guilds: + return + + self.guilds[event.guild_id].roles[event.role.id].update(event.role) def on_guild_role_delete(self, event): - pass + if event.guild_id not in self.guilds: + return + + del self.guilds[event.guild_id].roles[event.role.id] diff --git a/disco/types/base.py b/disco/types/base.py index 3a8a6f6..96e8588 100644 --- a/disco/types/base.py +++ b/disco/types/base.py @@ -2,6 +2,7 @@ import skema import functools from disco.util import skema_find_recursive_by_type +# from disco.util.types import DeferredModel class BaseType(skema.Model): diff --git a/disco/types/guild.py b/disco/types/guild.py index f8e1774..949d3ae 100644 --- a/disco/types/guild.py +++ b/disco/types/guild.py @@ -1,4 +1,5 @@ import skema +import copy from disco.api.http import APIException from disco.util import to_snowflake @@ -29,7 +30,7 @@ class Role(BaseType): class GuildMember(BaseType): user = skema.ModelType(User) - guild = skema.ModelType(Guild) + guild_id = skema.SnowflakeType(required=False) mute = skema.BooleanType() deaf = skema.BooleanType() joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) @@ -68,7 +69,7 @@ class Guild(BaseType): features = skema.ListType(skema.StringType()) - members = ListToDictType('id', skema.ModelType(GuildMember)) + members = ListToDictType('id', skema.ModelType(copy.deepcopy(GuildMember))) channels = ListToDictType('id', skema.ModelType(Channel)) roles = ListToDictType('id', skema.ModelType(Role)) emojis = ListToDictType('id', skema.ModelType(Emoji)) @@ -96,6 +97,7 @@ class Guild(BaseType): if self.members: for member in self.members.values(): member.guild = self + member.guild_id = self.id def validate_channels(self, ctx): if self.channels: diff --git a/disco/util/__init__.py b/disco/util/__init__.py index d745c08..78a14b8 100644 --- a/disco/util/__init__.py +++ b/disco/util/__init__.py @@ -25,7 +25,7 @@ def _recurse(typ, field, value): for item in value: if isinstance(field.field, typ): - result.append(item) + result.append((field.field, item)) result += _recurse(typ, field.field, item) return result @@ -41,7 +41,7 @@ def skema_find_recursive_by_type(base, typ): continue if isinstance(field, typ): - result.append(v) + result.append((field, v)) result += _recurse(typ, field, v) diff --git a/disco/util/websocket.py b/disco/util/websocket.py index 9c6f2ab..978adb9 100644 --- a/disco/util/websocket.py +++ b/disco/util/websocket.py @@ -1,16 +1,28 @@ from __future__ import absolute_import +import sys +import ssl import websocket import gevent import six +import gipc +import signal + +from holster.emitter import Emitter from disco.util.logging import LoggingClass class Websocket(LoggingClass, websocket.WebSocketApp): + """ + Subclass of websocket.WebSocketApp that adds some important improvements: + - Passes exit code to on_error callback in all cases + - Spawns callbacks in a gevent greenlet, and catches/logs exceptions + """ def __init__(self, *args, **kwargs): LoggingClass.__init__(self) websocket.WebSocketApp.__init__(self, *args, **kwargs) + def _get_close_args(self, data): if data and len(data) >= 2: code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) @@ -24,5 +36,71 @@ class Websocket(LoggingClass, websocket.WebSocketApp): try: gevent.spawn(callback, self, *args) - except Exception as e: + except Exception: self.log.exception('Error in Websocket callback for {}: '.format(callback)) + + +class WebsocketProcess(Websocket): + def __init__(self, pipe, *args, **kwargs): + Websocket.__init__(self, *args, **kwargs) + self.pipe = pipe + + # Hack to get events to emit + for var in self.__dict__.keys(): + if not var.startswith('on_'): + continue + + setattr(self, var, var) + + def _callback(self, callback, *args): + if not callback: + return + + self.pipe.put((callback, args)) + + +class WebsocketProcessProxy(object): + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.emitter = Emitter(gevent.spawn) + + gevent.signal(signal.SIGINT, self.handle_signal) + gevent.signal(signal.SIGQUIT, self.handle_signal) + gevent.signal(signal.SIGTERM, self.handle_signal) + + def handle_signal(self, *args): + self.close() + gevent.sleep(1) + self.process.terminate() + sys.exit() + + @classmethod + def process(cls, pipe, *args, **kwargs): + proc = WebsocketProcess(pipe, *args, **kwargs) + + # TODO: ssl? + gevent.spawn(proc.run_forever, sslopt={'cert_reqs': ssl.CERT_NONE}) + + while True: + op = pipe.get() + getattr(proc, op['method'])(*op['args'], **op['kwargs']) + + def read_task(self): + while True: + try: + name, args = self.pipe.get() + except EOFError: + return + self.emitter.emit(name, *args) + + def run_forever(self): + self.pipe, pipe = gipc.pipe(True) + self.process = gipc.start_process(self.process, args=tuple([pipe] + list(self.args)), kwargs=self.kwargs) + self.read_task() + + def __getattr__(self, attr): + def _wrapped(*args, **kwargs): + self.pipe.put({'method': attr, 'args': args, 'kwargs': kwargs}) + + return _wrapped diff --git a/disco/voice/client.py b/disco/voice/client.py index 1954ce5..f6bc138 100644 --- a/disco/voice/client.py +++ b/disco/voice/client.py @@ -72,7 +72,7 @@ class UDPVoiceClient(LoggingClass): return (None, None) # Read IP and port - ip = data[4:].split('\x00', 1)[0] + ip = str(data[4:]).split('\x00', 1)[0] port = struct.unpack('