Browse Source

Websocket subprocess, more state stuff

pull/3/head
Andrei 9 years ago
parent
commit
59c3f1ba3f
  1. 28
      disco/gateway/client.py
  2. 12
      disco/gateway/events.py
  3. 53
      disco/state.py
  4. 1
      disco/types/base.py
  5. 6
      disco/types/guild.py
  6. 4
      disco/util/__init__.py
  7. 80
      disco/util/websocket.py
  8. 2
      disco/voice/client.py

28
disco/gateway/client.py

@ -1,11 +1,10 @@
import gevent import gevent
import zlib import zlib
import ssl
from disco.gateway.packets import OPCode from disco.gateway.packets import OPCode
from disco.gateway.events import GatewayEvent from disco.gateway.events import GatewayEvent
from disco.util.json import loads, dumps from disco.util.json import loads, dumps
from disco.util.websocket import Websocket from disco.util.websocket import WebsocketProcessProxy
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
GATEWAY_VERSION = 6 GATEWAY_VERSION = 6
@ -94,16 +93,15 @@ class GatewayClient(LoggingClass):
self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json') self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json')
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url)
self.ws = Websocket( self.ws = WebsocketProcessProxy(self._cached_gateway_url)
self._cached_gateway_url, self.ws.emitter.on('on_open', self.on_open)
on_message=self.on_message, self.ws.emitter.on('on_error', self.on_error)
on_error=self.on_error, self.ws.emitter.on('on_close', self.on_close)
on_open=self.on_open, self.ws.emitter.on('on_message', self.on_message)
on_close=self.on_close,
) self.ws.run_forever()
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_message(self, msg):
def on_message(self, ws, msg):
# Detect zlib and decompress # Detect zlib and decompress
if msg[0] != '{': if msg[0] != '{':
msg = zlib.decompress(msg, 15, TEN_MEGABYTES).decode("utf-8") msg = zlib.decompress(msg, 15, TEN_MEGABYTES).decode("utf-8")
@ -121,12 +119,12 @@ class GatewayClient(LoggingClass):
# Emit packet # Emit packet
self.packets.emit(OPCode[data['op']], data) self.packets.emit(OPCode[data['op']], data)
def on_error(self, ws, error): def on_error(self, error):
if isinstance(error, KeyboardInterrupt): if isinstance(error, KeyboardInterrupt):
self.shutting_down = True self.shutting_down = True
raise Exception('WS recieved error: %s', error) raise Exception('WS recieved error: %s', error)
def on_open(self, ws): def on_open(self):
if self.seq and self.session_id: if self.seq and self.session_id:
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.send(OPCode.RESUME, { self.send(OPCode.RESUME, {
@ -149,7 +147,7 @@ class GatewayClient(LoggingClass):
} }
}) })
def on_close(self, ws, code, reason): def on_close(self, code, reason):
if self.shutting_down: if self.shutting_down:
self.log.info('WS Closed: shutting down') self.log.info('WS Closed: shutting down')
return return

12
disco/gateway/events.py

@ -7,15 +7,15 @@ from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceS
class GatewayEvent(skema.Model): class GatewayEvent(skema.Model):
@staticmethod @staticmethod
def from_dispatch(client, obj): def from_dispatch(client, data):
cls = globals().get(inflection.camelize(obj['t'].lower())) cls = globals().get(inflection.camelize(data['t'].lower()))
if not cls: if not cls:
raise Exception('Could not find cls for {}'.format(obj['t'])) raise Exception('Could not find cls for {}'.format(data['t']))
obj = cls.create(obj['d']) obj = cls.create(data['d'])
for item in skema_find_recursive_by_type(obj, skema.ModelType): for field, value in skema_find_recursive_by_type(obj, skema.ModelType):
item.client = client value.client = client
return obj return obj

53
disco/state.py

@ -1,6 +1,7 @@
from collections import defaultdict, deque, namedtuple from collections import defaultdict, deque, namedtuple
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from disco.gateway.packets import OPCode
StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id']) StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])
@ -42,6 +43,7 @@ class State(object):
self.client.events.on('GuildMemberAdd', self.on_guild_member_add) self.client.events.on('GuildMemberAdd', self.on_guild_member_add)
self.client.events.on('GuildMemberRemove', self.on_guild_member_remove) self.client.events.on('GuildMemberRemove', self.on_guild_member_remove)
self.client.events.on('GuildMemberUpdate', self.on_guild_member_update) self.client.events.on('GuildMemberUpdate', self.on_guild_member_update)
self.client.events.on('GuildMemberChunk', self.on_guild_member_chunk)
# Guild roles # Guild roles
self.client.events.on('GuildRoleCreate', self.on_guild_role_create) self.client.events.on('GuildRoleCreate', self.on_guild_role_create)
@ -93,6 +95,13 @@ class State(object):
for member in event.guild.members.values(): for member in event.guild.members.values():
self.users[member.user.id] = member.user self.users[member.user.id] = member.user
# Request full member list
self.client.gw.send(OPCode.REQUEST_GUILD_MEMBERS, {
'guild_id': event.guild.id,
'query': '',
'limit': 0,
})
def on_guild_update(self, event): def on_guild_update(self, event):
self.guilds[event.guild.id].update(event.guild) self.guilds[event.guild.id].update(event.guild)
@ -137,34 +146,52 @@ class State(object):
else: else:
event.member.user = self.users[event.member.user.id] event.member.user = self.users[event.member.user.id]
if event.member.guild.id not in self.guilds: if event.member.guild_id not in self.guilds:
return return
self.guilds[event.member.guild.id].members[event.member.id] = event.member event.member.guild = self.guilds[event.member.guild_id]
self.guilds[event.member.guild_id].members[event.member.id] = event.member
def on_guild_member_update(self, event): def on_guild_member_update(self, event):
if event.member.guild.id not in self.guilds: if event.guild_id not in self.guilds:
return return
# Ensure the reference is correct self.guilds[event.guild_id].members[event.user.id].roles = event.roles
assert(event.member.user.id in self.users) self.guilds[event.guild_id].members[event.user.id].user.update(event.user)
event.member.user = self.users[event.member.user.id]
self.guilds[event.member.guild.id].members[event.member.id] = event.member
def on_guild_member_remove(self, event): def on_guild_member_remove(self, event):
if event.member.guild.id not in self.guilds: if event.guild_id not in self.guilds:
return
if event.user.id not in self.guilds[event.guild_id].members:
return return
if event.member.id not in self.guilds[event.member.guild.id].members: del self.guilds[event.guild_id].members[event.user.id]
def on_guild_member_chunk(self, event):
if event.guild_id not in self.guilds:
return return
del self.guilds[event.member.guild.id].members[event.member.id] guild = self.guilds[event.guild_id]
for member in event.members:
member.guild = guild
member.guild_id = guild.id
guild.members[member.id] = member
def on_guild_role_create(self, event): def on_guild_role_create(self, event):
pass if event.guild_id not in self.guilds:
return
self.guilds[event.guild_id].roles[event.role.id] = event.role
def on_guild_role_update(self, event): def on_guild_role_update(self, event):
pass if event.guild_id not in self.guilds:
return
self.guilds[event.guild_id].roles[event.role.id].update(event.role)
def on_guild_role_delete(self, event): def on_guild_role_delete(self, event):
pass if event.guild_id not in self.guilds:
return
del self.guilds[event.guild_id].roles[event.role.id]

1
disco/types/base.py

@ -2,6 +2,7 @@ import skema
import functools import functools
from disco.util import skema_find_recursive_by_type from disco.util import skema_find_recursive_by_type
# from disco.util.types import DeferredModel
class BaseType(skema.Model): class BaseType(skema.Model):

6
disco/types/guild.py

@ -1,4 +1,5 @@
import skema import skema
import copy
from disco.api.http import APIException from disco.api.http import APIException
from disco.util import to_snowflake from disco.util import to_snowflake
@ -29,7 +30,7 @@ class Role(BaseType):
class GuildMember(BaseType): class GuildMember(BaseType):
user = skema.ModelType(User) user = skema.ModelType(User)
guild = skema.ModelType(Guild) guild_id = skema.SnowflakeType(required=False)
mute = skema.BooleanType() mute = skema.BooleanType()
deaf = skema.BooleanType() deaf = skema.BooleanType()
joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType()) joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())
@ -68,7 +69,7 @@ class Guild(BaseType):
features = skema.ListType(skema.StringType()) features = skema.ListType(skema.StringType())
members = ListToDictType('id', skema.ModelType(GuildMember)) members = ListToDictType('id', skema.ModelType(copy.deepcopy(GuildMember)))
channels = ListToDictType('id', skema.ModelType(Channel)) channels = ListToDictType('id', skema.ModelType(Channel))
roles = ListToDictType('id', skema.ModelType(Role)) roles = ListToDictType('id', skema.ModelType(Role))
emojis = ListToDictType('id', skema.ModelType(Emoji)) emojis = ListToDictType('id', skema.ModelType(Emoji))
@ -96,6 +97,7 @@ class Guild(BaseType):
if self.members: if self.members:
for member in self.members.values(): for member in self.members.values():
member.guild = self member.guild = self
member.guild_id = self.id
def validate_channels(self, ctx): def validate_channels(self, ctx):
if self.channels: if self.channels:

4
disco/util/__init__.py

@ -25,7 +25,7 @@ def _recurse(typ, field, value):
for item in value: for item in value:
if isinstance(field.field, typ): if isinstance(field.field, typ):
result.append(item) result.append((field.field, item))
result += _recurse(typ, field.field, item) result += _recurse(typ, field.field, item)
return result return result
@ -41,7 +41,7 @@ def skema_find_recursive_by_type(base, typ):
continue continue
if isinstance(field, typ): if isinstance(field, typ):
result.append(v) result.append((field, v))
result += _recurse(typ, field, v) result += _recurse(typ, field, v)

80
disco/util/websocket.py

@ -1,16 +1,28 @@
from __future__ import absolute_import from __future__ import absolute_import
import sys
import ssl
import websocket import websocket
import gevent import gevent
import six import six
import gipc
import signal
from holster.emitter import Emitter
from disco.util.logging import LoggingClass from disco.util.logging import LoggingClass
class Websocket(LoggingClass, websocket.WebSocketApp): class Websocket(LoggingClass, websocket.WebSocketApp):
"""
Subclass of websocket.WebSocketApp that adds some important improvements:
- Passes exit code to on_error callback in all cases
- Spawns callbacks in a gevent greenlet, and catches/logs exceptions
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
LoggingClass.__init__(self) LoggingClass.__init__(self)
websocket.WebSocketApp.__init__(self, *args, **kwargs) websocket.WebSocketApp.__init__(self, *args, **kwargs)
def _get_close_args(self, data): def _get_close_args(self, data):
if data and len(data) >= 2: if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
@ -24,5 +36,71 @@ class Websocket(LoggingClass, websocket.WebSocketApp):
try: try:
gevent.spawn(callback, self, *args) gevent.spawn(callback, self, *args)
except Exception as e: except Exception:
self.log.exception('Error in Websocket callback for {}: '.format(callback)) self.log.exception('Error in Websocket callback for {}: '.format(callback))
class WebsocketProcess(Websocket):
def __init__(self, pipe, *args, **kwargs):
Websocket.__init__(self, *args, **kwargs)
self.pipe = pipe
# Hack to get events to emit
for var in self.__dict__.keys():
if not var.startswith('on_'):
continue
setattr(self, var, var)
def _callback(self, callback, *args):
if not callback:
return
self.pipe.put((callback, args))
class WebsocketProcessProxy(object):
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.emitter = Emitter(gevent.spawn)
gevent.signal(signal.SIGINT, self.handle_signal)
gevent.signal(signal.SIGQUIT, self.handle_signal)
gevent.signal(signal.SIGTERM, self.handle_signal)
def handle_signal(self, *args):
self.close()
gevent.sleep(1)
self.process.terminate()
sys.exit()
@classmethod
def process(cls, pipe, *args, **kwargs):
proc = WebsocketProcess(pipe, *args, **kwargs)
# TODO: ssl?
gevent.spawn(proc.run_forever, sslopt={'cert_reqs': ssl.CERT_NONE})
while True:
op = pipe.get()
getattr(proc, op['method'])(*op['args'], **op['kwargs'])
def read_task(self):
while True:
try:
name, args = self.pipe.get()
except EOFError:
return
self.emitter.emit(name, *args)
def run_forever(self):
self.pipe, pipe = gipc.pipe(True)
self.process = gipc.start_process(self.process, args=tuple([pipe] + list(self.args)), kwargs=self.kwargs)
self.read_task()
def __getattr__(self, attr):
def _wrapped(*args, **kwargs):
self.pipe.put({'method': attr, 'args': args, 'kwargs': kwargs})
return _wrapped

2
disco/voice/client.py

@ -72,7 +72,7 @@ class UDPVoiceClient(LoggingClass):
return (None, None) return (None, None)
# Read IP and port # Read IP and port
ip = data[4:].split('\x00', 1)[0] ip = str(data[4:]).split('\x00', 1)[0]
port = struct.unpack('<H', data[-2:])[0] port = struct.unpack('<H', data[-2:])[0]
# Spawn read thread so we don't max buffers # Spawn read thread so we don't max buffers

Loading…
Cancel
Save