Browse Source

Websocket subprocess, more state stuff

pull/3/head
Andrei 9 years ago
parent
commit
59c3f1ba3f
  1. 28
      disco/gateway/client.py
  2. 12
      disco/gateway/events.py
  3. 53
      disco/state.py
  4. 1
      disco/types/base.py
  5. 6
      disco/types/guild.py
  6. 4
      disco/util/__init__.py
  7. 80
      disco/util/websocket.py
  8. 2
      disco/voice/client.py

28
disco/gateway/client.py

@ -1,11 +1,10 @@
import gevent
import zlib
import ssl
from disco.gateway.packets import OPCode
from disco.gateway.events import GatewayEvent
from disco.util.json import loads, dumps
from disco.util.websocket import Websocket
from disco.util.websocket import WebsocketProcessProxy
from disco.util.logging import LoggingClass
GATEWAY_VERSION = 6
@ -94,16 +93,15 @@ class GatewayClient(LoggingClass):
self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json')
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url)
self.ws = Websocket(
self._cached_gateway_url,
on_message=self.on_message,
on_error=self.on_error,
on_open=self.on_open,
on_close=self.on_close,
)
self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def on_message(self, ws, msg):
self.ws = WebsocketProcessProxy(self._cached_gateway_url)
self.ws.emitter.on('on_open', self.on_open)
self.ws.emitter.on('on_error', self.on_error)
self.ws.emitter.on('on_close', self.on_close)
self.ws.emitter.on('on_message', self.on_message)
self.ws.run_forever()
def on_message(self, msg):
# Detect zlib and decompress
if msg[0] != '{':
msg = zlib.decompress(msg, 15, TEN_MEGABYTES).decode("utf-8")
@ -121,12 +119,12 @@ class GatewayClient(LoggingClass):
# Emit packet
self.packets.emit(OPCode[data['op']], data)
def on_error(self, ws, error):
def on_error(self, error):
if isinstance(error, KeyboardInterrupt):
self.shutting_down = True
raise Exception('WS recieved error: %s', error)
def on_open(self, ws):
def on_open(self):
if self.seq and self.session_id:
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.send(OPCode.RESUME, {
@ -149,7 +147,7 @@ class GatewayClient(LoggingClass):
}
})
def on_close(self, ws, code, reason):
def on_close(self, code, reason):
if self.shutting_down:
self.log.info('WS Closed: shutting down')
return

12
disco/gateway/events.py

@ -7,15 +7,15 @@ from disco.types import Guild, Channel, User, GuildMember, Role, Message, VoiceS
class GatewayEvent(skema.Model):
@staticmethod
def from_dispatch(client, obj):
cls = globals().get(inflection.camelize(obj['t'].lower()))
def from_dispatch(client, data):
cls = globals().get(inflection.camelize(data['t'].lower()))
if not cls:
raise Exception('Could not find cls for {}'.format(obj['t']))
raise Exception('Could not find cls for {}'.format(data['t']))
obj = cls.create(obj['d'])
obj = cls.create(data['d'])
for item in skema_find_recursive_by_type(obj, skema.ModelType):
item.client = client
for field, value in skema_find_recursive_by_type(obj, skema.ModelType):
value.client = client
return obj

53
disco/state.py

@ -1,6 +1,7 @@
from collections import defaultdict, deque, namedtuple
from weakref import WeakValueDictionary
from disco.gateway.packets import OPCode
StackMessage = namedtuple('StackMessage', ['id', 'channel_id', 'author_id'])
@ -42,6 +43,7 @@ class State(object):
self.client.events.on('GuildMemberAdd', self.on_guild_member_add)
self.client.events.on('GuildMemberRemove', self.on_guild_member_remove)
self.client.events.on('GuildMemberUpdate', self.on_guild_member_update)
self.client.events.on('GuildMemberChunk', self.on_guild_member_chunk)
# Guild roles
self.client.events.on('GuildRoleCreate', self.on_guild_role_create)
@ -93,6 +95,13 @@ class State(object):
for member in event.guild.members.values():
self.users[member.user.id] = member.user
# Request full member list
self.client.gw.send(OPCode.REQUEST_GUILD_MEMBERS, {
'guild_id': event.guild.id,
'query': '',
'limit': 0,
})
def on_guild_update(self, event):
self.guilds[event.guild.id].update(event.guild)
@ -137,34 +146,52 @@ class State(object):
else:
event.member.user = self.users[event.member.user.id]
if event.member.guild.id not in self.guilds:
if event.member.guild_id not in self.guilds:
return
self.guilds[event.member.guild.id].members[event.member.id] = event.member
event.member.guild = self.guilds[event.member.guild_id]
self.guilds[event.member.guild_id].members[event.member.id] = event.member
def on_guild_member_update(self, event):
if event.member.guild.id not in self.guilds:
if event.guild_id not in self.guilds:
return
# Ensure the reference is correct
assert(event.member.user.id in self.users)
event.member.user = self.users[event.member.user.id]
self.guilds[event.member.guild.id].members[event.member.id] = event.member
self.guilds[event.guild_id].members[event.user.id].roles = event.roles
self.guilds[event.guild_id].members[event.user.id].user.update(event.user)
def on_guild_member_remove(self, event):
if event.member.guild.id not in self.guilds:
if event.guild_id not in self.guilds:
return
if event.member.id not in self.guilds[event.member.guild.id].members:
if event.user.id not in self.guilds[event.guild_id].members:
return
del self.guilds[event.member.guild.id].members[event.member.id]
del self.guilds[event.guild_id].members[event.user.id]
def on_guild_member_chunk(self, event):
if event.guild_id not in self.guilds:
return
guild = self.guilds[event.guild_id]
for member in event.members:
member.guild = guild
member.guild_id = guild.id
guild.members[member.id] = member
def on_guild_role_create(self, event):
pass
if event.guild_id not in self.guilds:
return
self.guilds[event.guild_id].roles[event.role.id] = event.role
def on_guild_role_update(self, event):
pass
if event.guild_id not in self.guilds:
return
self.guilds[event.guild_id].roles[event.role.id].update(event.role)
def on_guild_role_delete(self, event):
pass
if event.guild_id not in self.guilds:
return
del self.guilds[event.guild_id].roles[event.role.id]

1
disco/types/base.py

@ -2,6 +2,7 @@ import skema
import functools
from disco.util import skema_find_recursive_by_type
# from disco.util.types import DeferredModel
class BaseType(skema.Model):

6
disco/types/guild.py

@ -1,4 +1,5 @@
import skema
import copy
from disco.api.http import APIException
from disco.util import to_snowflake
@ -29,7 +30,7 @@ class Role(BaseType):
class GuildMember(BaseType):
user = skema.ModelType(User)
guild = skema.ModelType(Guild)
guild_id = skema.SnowflakeType(required=False)
mute = skema.BooleanType()
deaf = skema.BooleanType()
joined_at = PreHookType(lambda k: k[:-6], skema.DateTimeType())
@ -68,7 +69,7 @@ class Guild(BaseType):
features = skema.ListType(skema.StringType())
members = ListToDictType('id', skema.ModelType(GuildMember))
members = ListToDictType('id', skema.ModelType(copy.deepcopy(GuildMember)))
channels = ListToDictType('id', skema.ModelType(Channel))
roles = ListToDictType('id', skema.ModelType(Role))
emojis = ListToDictType('id', skema.ModelType(Emoji))
@ -96,6 +97,7 @@ class Guild(BaseType):
if self.members:
for member in self.members.values():
member.guild = self
member.guild_id = self.id
def validate_channels(self, ctx):
if self.channels:

4
disco/util/__init__.py

@ -25,7 +25,7 @@ def _recurse(typ, field, value):
for item in value:
if isinstance(field.field, typ):
result.append(item)
result.append((field.field, item))
result += _recurse(typ, field.field, item)
return result
@ -41,7 +41,7 @@ def skema_find_recursive_by_type(base, typ):
continue
if isinstance(field, typ):
result.append(v)
result.append((field, v))
result += _recurse(typ, field, v)

80
disco/util/websocket.py

@ -1,16 +1,28 @@
from __future__ import absolute_import
import sys
import ssl
import websocket
import gevent
import six
import gipc
import signal
from holster.emitter import Emitter
from disco.util.logging import LoggingClass
class Websocket(LoggingClass, websocket.WebSocketApp):
"""
Subclass of websocket.WebSocketApp that adds some important improvements:
- Passes exit code to on_error callback in all cases
- Spawns callbacks in a gevent greenlet, and catches/logs exceptions
"""
def __init__(self, *args, **kwargs):
LoggingClass.__init__(self)
websocket.WebSocketApp.__init__(self, *args, **kwargs)
def _get_close_args(self, data):
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
@ -24,5 +36,71 @@ class Websocket(LoggingClass, websocket.WebSocketApp):
try:
gevent.spawn(callback, self, *args)
except Exception as e:
except Exception:
self.log.exception('Error in Websocket callback for {}: '.format(callback))
class WebsocketProcess(Websocket):
def __init__(self, pipe, *args, **kwargs):
Websocket.__init__(self, *args, **kwargs)
self.pipe = pipe
# Hack to get events to emit
for var in self.__dict__.keys():
if not var.startswith('on_'):
continue
setattr(self, var, var)
def _callback(self, callback, *args):
if not callback:
return
self.pipe.put((callback, args))
class WebsocketProcessProxy(object):
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.emitter = Emitter(gevent.spawn)
gevent.signal(signal.SIGINT, self.handle_signal)
gevent.signal(signal.SIGQUIT, self.handle_signal)
gevent.signal(signal.SIGTERM, self.handle_signal)
def handle_signal(self, *args):
self.close()
gevent.sleep(1)
self.process.terminate()
sys.exit()
@classmethod
def process(cls, pipe, *args, **kwargs):
proc = WebsocketProcess(pipe, *args, **kwargs)
# TODO: ssl?
gevent.spawn(proc.run_forever, sslopt={'cert_reqs': ssl.CERT_NONE})
while True:
op = pipe.get()
getattr(proc, op['method'])(*op['args'], **op['kwargs'])
def read_task(self):
while True:
try:
name, args = self.pipe.get()
except EOFError:
return
self.emitter.emit(name, *args)
def run_forever(self):
self.pipe, pipe = gipc.pipe(True)
self.process = gipc.start_process(self.process, args=tuple([pipe] + list(self.args)), kwargs=self.kwargs)
self.read_task()
def __getattr__(self, attr):
def _wrapped(*args, **kwargs):
self.pipe.put({'method': attr, 'args': args, 'kwargs': kwargs})
return _wrapped

2
disco/voice/client.py

@ -72,7 +72,7 @@ class UDPVoiceClient(LoggingClass):
return (None, None)
# Read IP and port
ip = data[4:].split('\x00', 1)[0]
ip = str(data[4:]).split('\x00', 1)[0]
port = struct.unpack('<H', data[-2:])[0]
# Spawn read thread so we don't max buffers

Loading…
Cancel
Save