diff --git a/README.md b/README.md index 339ee16..7bcaf17 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ A Discord Python bot built to be easy to use and scale. ## TODOS -- flesh out gateway paths (reconnect/resume) - flesh out API client - storage/database/config - flesh out type methods diff --git a/disco/client.py b/disco/client.py index d4df839..0d6cc26 100644 --- a/disco/client.py +++ b/disco/client.py @@ -23,7 +23,7 @@ class DiscoClient(object): self.gw = GatewayClient(self) def run(self): - return self.gw.run() + return gevent.spawn(self.gw.run) def run_forever(self): - return self.gw.run().join() + return self.gw.run() diff --git a/disco/gateway/client.py b/disco/gateway/client.py index a677777..47eecc0 100644 --- a/disco/gateway/client.py +++ b/disco/gateway/client.py @@ -2,6 +2,7 @@ import websocket import gevent import json import zlib +import six from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket from disco.gateway.events import GatewayEvent @@ -11,17 +12,18 @@ GATEWAY_VERSION = 6 TEN_MEGABYTES = 10490000 -def log_error(log, msg, w): - def _f(*args, **kwargs): - try: - return w(*args, **kwargs) - except: - log.exception(msg) - raise - return _f +# Hack to get websocket close information +def websocket_get_close_args_override(data): + if data and len(data) >= 2: + code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2]) + reason = data[2:].decode('utf-8') + return [code, reason] + return [None, None] class GatewayClient(LoggingClass): + MAX_RECONNECTS = 5 + def __init__(self, client): super(GatewayClient, self).__init__() self.client = client @@ -42,8 +44,6 @@ class GatewayClient(LoggingClass): # Heartbeat self._heartbeat_task = None - self._fatal_error_promise = gevent.event.AsyncResult() - def send(self, packet): self.ws.send(json.dumps({ 'op': int(packet.OP), @@ -61,13 +61,17 @@ class GatewayClient(LoggingClass): self.client.events.emit(obj.__class__.__name__, obj) def handle_heartbeat(self, packet): - pass + self.send(HeartbeatPacket(data=self.seq)) def handle_reconnect(self, packet): - pass + self.log.warning('Received RECONNECT request, forcing a fresh reconnect') + self.session_id = None + self.ws.close() def handle_invalid_session(self, packet): - pass + self.log.warning('Recieved INVALID_SESSIOIN, forcing a fresh reconnect') + self.sesion_id = None + self.ws.close() def handle_hello(self, packet): self.log.info('Recieved HELLO, starting heartbeater...') @@ -88,11 +92,12 @@ class GatewayClient(LoggingClass): self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) self.ws = websocket.WebSocketApp( self._cached_gateway_url, - on_message=log_error(self.log, 'Error in on_message:', self.on_message), - on_error=log_error(self.log, 'Error in on_error:', self.on_error), - on_open=log_error(self.log, 'Error in on_open:', self.on_open), - on_close=log_error(self.log, 'Error in on_close:', self.on_close), + on_message=self.log_on_error('Error in on_message:', self.on_message), + on_error=self.log_on_error('Error in on_error:', self.on_error), + on_open=self.log_on_error('Error in on_open:', self.on_open), + on_close=self.log_on_error('Error in on_close:', self.on_close), ) + self.ws._get_close_args = websocket_get_close_args_override def on_message(self, ws, msg): # Detect zlib and decompress @@ -125,27 +130,41 @@ class GatewayClient(LoggingClass): raise Exception('Unknown packet: {}'.format(data['op'])) def on_error(self, ws, error): - print 'error', error + raise Exception('WS recieved error: %s', error) def on_open(self, ws): - print 'on open' if self.seq and self.session_id: + self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq) self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) else: + self.log.info('WS Opened: sending identify payload') self.send(IdentifyPacket( token=self.client.token, compress=True, large_threshold=250, shard=[self.client.sharding['number'], self.client.sharding['total']])) - def on_close(self, ws): - print 'close' + def on_close(self, ws, code, reason): + self.reconnects += 1 + self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects) - def run(self): - self.connect() + if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS: + raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS)) + + # Don't resume for these error codes + if 4000 <= code <= 4010: + self.session_id = None + self.log.info('Attempting fresh reconnect') + else: + self.log.info('Attempting resume') - # Spawn a thread to run the connection loop forever - gevent.spawn(self.ws.run_forever) + wait_time = self.reconnects * 5 + self.log.info('Will attempt to {} after {} seconds', 'resume' if self.session_id else 'reconnect', wait_time) + gevent.sleep(wait_time) - # Wait for a fatal error - self._fatal_error_promise.get() + # Reconnect + self.connect() + + def run(self): + self.connect() + self.ws.run_forever() diff --git a/disco/util/logging.py b/disco/util/logging.py index 24bfe6c..7feca4d 100644 --- a/disco/util/logging.py +++ b/disco/util/logging.py @@ -6,3 +6,12 @@ import logging class LoggingClass(object): def __init__(self): self.log = logging.getLogger(self.__class__.__name__) + + def log_on_error(self, msg, f): + def _f(*args, **kwargs): + try: + return f(*args, **kwargs) + except: + self.log.exception(msg) + raise + return _f