You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

175 lines
5.8 KiB

import websocket
import gevent
import json
import zlib
import six
import ssl
from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket
from disco.gateway.events import GatewayEvent
from disco.util.logging import LoggingClass
# Protocol version requested from `client.api.gateway` when resolving the
# websocket URL.
GATEWAY_VERSION = 6

# Initial output-buffer size (bytes) handed to zlib.decompress for compressed
# gateway payloads. Despite the name it is ~10 MB, not exactly 10 * 2**20.
TEN_MEGABYTES = 10490000
# Hack to get websocket close information
def websocket_get_close_args_override(data):
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]
return [None, None]
class GatewayClient(LoggingClass):
    """Gateway websocket client.

    Opens the websocket, identifies (or resumes a previous session), runs a
    heartbeat greenlet, routes incoming packets by opcode, and reconnects with
    a linear back-off when the socket closes.
    """

    # Consecutive closes tolerated (since the last READY) before giving up.
    # A falsy value disables the limit.
    MAX_RECONNECTS = 5

    def __init__(self, client):
        super(GatewayClient, self).__init__()
        self.client = client

        self.client.events.on('Ready', self.on_ready)

        # Websocket connection (created in connect_and_run)
        self.ws = None

        # State
        self.seq = 0              # highest sequence number seen so far
        self.session_id = None    # set on READY; required for resuming
        self.reconnects = 0       # consecutive closes since the last READY
        self.shutting_down = False

        # Cached gateway URL
        self._cached_gateway_url = None

        # Heartbeat greenlet (spawned on HELLO)
        self._heartbeat_task = None

    def send(self, packet):
        """Serialize ``packet`` as ``{'op': ..., 'd': ...}`` JSON and send it."""
        self.ws.send(json.dumps({
            'op': int(packet.OP),
            'd': packet.to_dict(),
        }))

    def heartbeat_task(self, interval):
        """Send a heartbeat carrying the latest sequence every ``interval`` ms.

        Runs until the greenlet is killed (see _stop_heartbeat) or a send fails.
        """
        while True:
            self.send(HeartbeatPacket(data=self.seq))
            # Float division so Python 2 does not truncate the interval.
            gevent.sleep(interval / 1000.0)

    def _stop_heartbeat(self):
        """Kill the running heartbeat greenlet, if there is one."""
        if self._heartbeat_task is not None:
            self._heartbeat_task.kill()
            self._heartbeat_task = None

    def handle_dispatch(self, packet):
        """Build a GatewayEvent from a DISPATCH packet and emit it on
        ``client.events`` under the event's class name."""
        obj = GatewayEvent.from_dispatch(self.client, packet)
        self.log.debug('Dispatching %s', obj.__class__.__name__)
        self.client.events.emit(obj.__class__.__name__, obj)

    def handle_heartbeat(self, packet):
        """Answer a server-requested heartbeat immediately."""
        self.send(HeartbeatPacket(data=self.seq))

    def handle_reconnect(self, packet):
        """Drop the session and close the socket so on_close reconnects fresh."""
        self.log.warning('Received RECONNECT request, forcing a fresh reconnect')
        self.session_id = None
        self.ws.close()

    def handle_invalid_session(self, packet):
        """Drop the session and close the socket so on_close reconnects fresh."""
        self.log.warning('Received INVALID_SESSION, forcing a fresh reconnect')
        # Previously misspelled as `sesion_id`, which left the stale session in
        # place and made the next connection try to resume it again.
        self.session_id = None
        self.ws.close()

    def handle_hello(self, packet):
        """Start heartbeating at the interval announced in the HELLO payload."""
        self.log.info('Received HELLO, starting heartbeater...')
        # Each HELLO spawns a heartbeater; stop any previous one (e.g. left
        # over from an earlier connection) so heartbeats don't pile up.
        self._stop_heartbeat()
        self._heartbeat_task = gevent.spawn(self.heartbeat_task, packet['d']['heartbeat_interval'])

    def handle_heartbeat_ack(self, packet):
        """Heartbeat acknowledgements are currently ignored."""
        pass

    def on_ready(self, ready):
        """Remember the new session id and reset the reconnect counter."""
        self.log.info('Received READY')
        self.session_id = ready.session_id
        self.reconnects = 0

    def connect_and_run(self):
        """Open a websocket to the (cached) gateway URL and block until it closes."""
        if not self._cached_gateway_url:
            self._cached_gateway_url = self.client.api.gateway(version=GATEWAY_VERSION, encoding='json')

        self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url)
        self.ws = websocket.WebSocketApp(
            self._cached_gateway_url,
            on_message=self.log_on_error('Error in on_message:', self.on_message),
            on_error=self.log_on_error('Error in on_error:', self.on_error),
            on_open=self.log_on_error('Error in on_open:', self.on_open),
            on_close=self.log_on_error('Error in on_close:', self.on_close),
        )
        # Replace the library's close-argument parsing so on_close always gets
        # (code, reason); the wrapped callbacks presumably defeat the library's
        # own callback-arity check.
        self.ws._get_close_args = websocket_get_close_args_override
        # NOTE(review): CERT_NONE disables TLS certificate verification, so the
        # token sent in IDENTIFY/RESUME is exposed to a man-in-the-middle.
        # Consider dropping this sslopt (verification on) once CA certs are
        # available in all deployments.
        self.ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def on_message(self, ws, msg):
        """Decode one gateway frame and route it to the handler for its opcode.

        Raises:
            Exception: for an opcode this client does not know.
        """
        try:
            # Plain payloads are JSON text; anything else is treated as a zlib
            # stream. Slicing (not indexing) gives a 1-length str/bytes on both
            # Python 2 and 3, so text and bytes frames are both detected.
            if msg[:1] not in ('{', b'{'):
                msg = zlib.decompress(msg, 15, TEN_MEGABYTES)
            data = json.loads(msg)
        except Exception:
            self.log.exception('Failed to load dispatch:')
            return

        # Update sequence
        seq = data.get('s')
        if seq and seq > self.seq:
            self.seq = seq

        op = data['op']
        if op == OPCode.DISPATCH:
            self.handle_dispatch(data)
        elif op == OPCode.HEARTBEAT:
            self.handle_heartbeat(data)
        elif op == OPCode.RECONNECT:
            self.handle_reconnect(data)
        elif op == OPCode.INVALID_SESSION:
            self.handle_invalid_session(data)
        elif op == OPCode.HELLO:
            self.handle_hello(data)
        elif op == OPCode.HEARTBEAT_ACK:
            self.handle_heartbeat_ack(data)
        else:
            raise Exception('Unknown packet: {}'.format(op))

    def on_error(self, ws, error):
        """Mark shutdown on KeyboardInterrupt, then re-raise the error.

        Raises:
            Exception: always, wrapping ``error`` in its message.
        """
        if isinstance(error, KeyboardInterrupt):
            self.shutting_down = True
        # Previously passed ('...%s', error) as two args, so the message was
        # never formatted.
        raise Exception('WS received error: {}'.format(error))

    def on_open(self, ws):
        """Resume the previous session when possible, otherwise IDENTIFY."""
        if self.seq and self.session_id:
            self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
            self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token))
        else:
            # A fresh session starts its own sequence numbering; a leftover
            # value would make on_message's `> self.seq` check ignore the new
            # session's sequence numbers until they overtook it.
            self.seq = 0
            self.log.info('WS Opened: sending identify payload')
            self.send(IdentifyPacket(
                token=self.client.token,
                compress=True,
                large_threshold=250,
                shard=[self.client.sharding['number'], self.client.sharding['total']]))

    def on_close(self, ws, code, reason):
        """Stop heartbeating, then back off and reconnect (unless shutting down).

        Raises:
            Exception: after more than MAX_RECONNECTS consecutive closes.
        """
        # The heartbeater would otherwise keep running against the closed
        # (and soon replaced) socket.
        self._stop_heartbeat()

        if self.shutting_down:
            self.log.info('WS Closed: shutting down')
            return

        self.reconnects += 1
        self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects)

        if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS:
            raise Exception('Failed to reconnect after {} attempts, giving up'.format(self.MAX_RECONNECTS))

        # Don't resume for these error codes
        if code and 4000 <= code <= 4010:
            self.session_id = None

        # Linear back-off: 5s, 10s, 15s, ...
        wait_time = self.reconnects * 5
        self.log.info('Will attempt to %s after %s seconds', 'resume' if self.session_id else 'reconnect', wait_time)
        gevent.sleep(wait_time)

        # Reconnect.
        # NOTE(review): this recurses (run_forever -> on_close ->
        # connect_and_run), so each reconnect adds stack frames that are never
        # unwound; driving reconnects from a loop in run() would avoid that.
        self.connect_and_run()

    def run(self):
        """Entry point: connect and block in the websocket loop."""
        self.connect_and_run()