Browse Source

Flesh out gateway client

pull/3/head
Andrei 9 years ago
parent
commit
a240bd1e90
  1. 1
      README.md
  2. 4
      disco/client.py
  3. 73
      disco/gateway/client.py
  4. 9
      disco/util/logging.py

1
README.md

@ -3,7 +3,6 @@ A Discord Python bot built to be easy to use and scale.
## TODOS ## TODOS
- flesh out gateway paths (reconnect/resume)
- flesh out API client - flesh out API client
- storage/database/config - storage/database/config
- flesh out type methods - flesh out type methods

4
disco/client.py

@ -23,7 +23,7 @@ class DiscoClient(object):
self.gw = GatewayClient(self) self.gw = GatewayClient(self)
def run(self): def run(self):
return self.gw.run() return gevent.spawn(self.gw.run)
def run_forever(self): def run_forever(self):
return self.gw.run().join() return self.gw.run()

73
disco/gateway/client.py

@ -2,6 +2,7 @@ import websocket
import gevent import gevent
import json import json
import zlib import zlib
import six
from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket from disco.gateway.packets import OPCode, HeartbeatPacket, ResumePacket, IdentifyPacket
from disco.gateway.events import GatewayEvent from disco.gateway.events import GatewayEvent
@ -11,17 +12,18 @@ GATEWAY_VERSION = 6
TEN_MEGABYTES = 10490000 TEN_MEGABYTES = 10490000
def log_error(log, msg, w): # Hack to get websocket close information
def _f(*args, **kwargs): def websocket_get_close_args_override(data):
try: if data and len(data) >= 2:
return w(*args, **kwargs) code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
except: reason = data[2:].decode('utf-8')
log.exception(msg) return [code, reason]
raise return [None, None]
return _f
class GatewayClient(LoggingClass): class GatewayClient(LoggingClass):
MAX_RECONNECTS = 5
def __init__(self, client): def __init__(self, client):
super(GatewayClient, self).__init__() super(GatewayClient, self).__init__()
self.client = client self.client = client
@ -42,8 +44,6 @@ class GatewayClient(LoggingClass):
# Heartbeat # Heartbeat
self._heartbeat_task = None self._heartbeat_task = None
self._fatal_error_promise = gevent.event.AsyncResult()
def send(self, packet): def send(self, packet):
self.ws.send(json.dumps({ self.ws.send(json.dumps({
'op': int(packet.OP), 'op': int(packet.OP),
@ -61,13 +61,17 @@ class GatewayClient(LoggingClass):
self.client.events.emit(obj.__class__.__name__, obj) self.client.events.emit(obj.__class__.__name__, obj)
def handle_heartbeat(self, packet): def handle_heartbeat(self, packet):
pass self.send(HeartbeatPacket(data=self.seq))
def handle_reconnect(self, packet): def handle_reconnect(self, packet):
pass self.log.warning('Received RECONNECT request, forcing a fresh reconnect')
self.session_id = None
self.ws.close()
def handle_invalid_session(self, packet): def handle_invalid_session(self, packet):
pass self.log.warning('Recieved INVALID_SESSIOIN, forcing a fresh reconnect')
self.sesion_id = None
self.ws.close()
def handle_hello(self, packet): def handle_hello(self, packet):
self.log.info('Recieved HELLO, starting heartbeater...') self.log.info('Recieved HELLO, starting heartbeater...')
@ -88,11 +92,12 @@ class GatewayClient(LoggingClass):
self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url) self.log.info('Opening websocket connection to URL `%s`', self._cached_gateway_url)
self.ws = websocket.WebSocketApp( self.ws = websocket.WebSocketApp(
self._cached_gateway_url, self._cached_gateway_url,
on_message=log_error(self.log, 'Error in on_message:', self.on_message), on_message=self.log_on_error('Error in on_message:', self.on_message),
on_error=log_error(self.log, 'Error in on_error:', self.on_error), on_error=self.log_on_error('Error in on_error:', self.on_error),
on_open=log_error(self.log, 'Error in on_open:', self.on_open), on_open=self.log_on_error('Error in on_open:', self.on_open),
on_close=log_error(self.log, 'Error in on_close:', self.on_close), on_close=self.log_on_error('Error in on_close:', self.on_close),
) )
self.ws._get_close_args = websocket_get_close_args_override
def on_message(self, ws, msg): def on_message(self, ws, msg):
# Detect zlib and decompress # Detect zlib and decompress
@ -125,27 +130,41 @@ class GatewayClient(LoggingClass):
raise Exception('Unknown packet: {}'.format(data['op'])) raise Exception('Unknown packet: {}'.format(data['op']))
def on_error(self, ws, error): def on_error(self, ws, error):
print 'error', error raise Exception('WS recieved error: %s', error)
def on_open(self, ws): def on_open(self, ws):
print 'on open'
if self.seq and self.session_id: if self.seq and self.session_id:
self.log.info('WS Opened: attempting resume w/ SID: %s SEQ: %s', self.session_id, self.seq)
self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token)) self.send(ResumePacket(seq=self.seq, session_id=self.session_id, token=self.client.token))
else: else:
self.log.info('WS Opened: sending identify payload')
self.send(IdentifyPacket( self.send(IdentifyPacket(
token=self.client.token, token=self.client.token,
compress=True, compress=True,
large_threshold=250, large_threshold=250,
shard=[self.client.sharding['number'], self.client.sharding['total']])) shard=[self.client.sharding['number'], self.client.sharding['total']]))
def on_close(self, ws): def on_close(self, ws, code, reason):
print 'close' self.reconnects += 1
self.log.info('WS Closed: [%s] %s (%s)', code, reason, self.reconnects)
def run(self): if self.MAX_RECONNECTS and self.reconnects > self.MAX_RECONNECTS:
self.connect() raise Exception('Failed to reconect after {} attempts, giving up'.format(self.MAX_RECONNECTS))
# Don't resume for these error codes
if 4000 <= code <= 4010:
self.session_id = None
self.log.info('Attempting fresh reconnect')
else:
self.log.info('Attempting resume')
# Spawn a thread to run the connection loop forever wait_time = self.reconnects * 5
gevent.spawn(self.ws.run_forever) self.log.info('Will attempt to {} after {} seconds', 'resume' if self.session_id else 'reconnect', wait_time)
gevent.sleep(wait_time)
# Wait for a fatal error # Reconnect
self._fatal_error_promise.get() self.connect()
def run(self):
self.connect()
self.ws.run_forever()

9
disco/util/logging.py

@ -6,3 +6,12 @@ import logging
class LoggingClass(object): class LoggingClass(object):
def __init__(self): def __init__(self):
self.log = logging.getLogger(self.__class__.__name__) self.log = logging.getLogger(self.__class__.__name__)
def log_on_error(self, msg, f):
def _f(*args, **kwargs):
try:
return f(*args, **kwargs)
except:
self.log.exception(msg)
raise
return _f

Loading…
Cancel
Save