|
|
@ -7,13 +7,16 @@ from gevent import queue |
|
|
|
from gevent import event |
|
|
|
from gevent.select import select as gselect |
|
|
|
|
|
|
|
import ssl |
|
|
|
import certifi |
|
|
|
|
|
|
|
from wsproto import WSConnection, events as wsevents |
|
|
|
from wsproto.connection import ConnectionType, ConnectionState |
|
|
|
|
|
|
|
logger = logging.getLogger("Connection") |
|
|
|
|
|
|
|
|
|
|
|
class Connection(object): |
|
|
|
MAGIC = b'VT01' |
|
|
|
FMT = '<I4s' |
|
|
|
FMT_SIZE = struct.calcsize(FMT) |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
self.socket = None |
|
|
@ -48,8 +51,9 @@ class Connection(object): |
|
|
|
self._reader = gevent.spawn(self._reader_loop) |
|
|
|
self._writer = gevent.spawn(self._writer_loop) |
|
|
|
|
|
|
|
logger.debug("Connected.") |
|
|
|
self.event_connected.set() |
|
|
|
# how this gets set is implementation dependent |
|
|
|
self.event_connected.wait(timeout=10) |
|
|
|
|
|
|
|
return True |
|
|
|
|
|
|
|
def disconnect(self): |
|
|
@ -80,11 +84,46 @@ class Connection(object): |
|
|
|
|
|
|
|
def put_message(self, message): |
|
|
|
self.send_queue.put(message) |
|
|
|
|
|
|
|
def _new_socket(self): |
|
|
|
raise TypeError("{}: _new_socket is unimplemented".format(self.__class__.__name__)) |
|
|
|
|
|
|
|
def _connect(self, server_addr): |
|
|
|
raise TypeError("{}: _connect is unimplemented".format(self.__class__.__name__)) |
|
|
|
|
|
|
|
def _reader_loop(self): |
|
|
|
raise TypeError("{}: _reader_loop is unimplemented".format(self.__class__.__name__)) |
|
|
|
|
|
|
|
def _writer_loop(self): |
|
|
|
raise TypeError("{}: _writer_loop is unimplemented".format(self.__class__.__name__)) |
|
|
|
|
|
|
|
class TCPConnection(Connection): |
|
|
|
|
|
|
|
MAGIC = b'VT01' |
|
|
|
FMT = '<I4s' |
|
|
|
FMT_SIZE = struct.calcsize(FMT) |
|
|
|
|
|
|
|
def _new_socket(self): |
|
|
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
|
|
|
|
def _connect(self, server_addr): |
|
|
|
self.socket.connect(server_addr) |
|
|
|
logger.debug("Connected.") |
|
|
|
self.event_connected.set() |
|
|
|
|
|
|
|
def _read_data(self): |
|
|
|
try: |
|
|
|
return self.socket.recv(16384) |
|
|
|
except socket.error: |
|
|
|
return '' |
|
|
|
|
|
|
|
def _write_data(self, data): |
|
|
|
self.socket.sendall(data) |
|
|
|
|
|
|
|
def _writer_loop(self): |
|
|
|
while True: |
|
|
|
message = self.send_queue.get() |
|
|
|
packet = struct.pack(Connection.FMT, len(message), Connection.MAGIC) + message |
|
|
|
packet = struct.pack(TCPConnection.FMT, len(message), TCPConnection.MAGIC) + message |
|
|
|
try: |
|
|
|
self._write_data(packet) |
|
|
|
except: |
|
|
@ -108,13 +147,13 @@ class Connection(object): |
|
|
|
self._read_packets() |
|
|
|
|
|
|
|
def _read_packets(self): |
|
|
|
header_size = Connection.FMT_SIZE |
|
|
|
header_size = TCPConnection.FMT_SIZE |
|
|
|
buf = self._readbuf |
|
|
|
|
|
|
|
while len(buf) > header_size: |
|
|
|
message_length, magic = struct.unpack_from(Connection.FMT, buf) |
|
|
|
message_length, magic = struct.unpack_from(TCPConnection.FMT, buf) |
|
|
|
|
|
|
|
if magic != Connection.MAGIC: |
|
|
|
if magic != TCPConnection.MAGIC: |
|
|
|
logger.debug("invalid magic, got %s" % repr(magic)) |
|
|
|
self.disconnect() |
|
|
|
return |
|
|
@ -131,33 +170,109 @@ class Connection(object): |
|
|
|
|
|
|
|
self._readbuf = buf |
|
|
|
|
|
|
|
class WebsocketConnection(Connection): |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super(WebsocketConnection, self).__init__() |
|
|
|
self.ws = WSConnection(ConnectionType.CLIENT) |
|
|
|
self.ssl_ctx = ssl.create_default_context(cafile=certifi.where()) |
|
|
|
self.event_wsdisconnected = event.Event() |
|
|
|
|
|
|
|
class TCPConnection(Connection): |
|
|
|
def _new_socket(self): |
|
|
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
self.raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
|
|
|
|
def _connect(self, server_addr): |
|
|
|
self.socket.connect(server_addr) |
|
|
|
|
|
|
|
def _read_data(self): |
|
|
|
try: |
|
|
|
return self.socket.recv(16384) |
|
|
|
except socket.error: |
|
|
|
return '' |
|
|
|
host, port = server_addr |
|
|
|
|
|
|
|
for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): |
|
|
|
try: |
|
|
|
# tcp socket |
|
|
|
_, _, _, _, sa = res |
|
|
|
self.raw_socket.connect(sa) |
|
|
|
self.socket = self.ssl_ctx.wrap_socket(self.raw_socket, server_hostname=host) |
|
|
|
# websocket |
|
|
|
ws_host = ':'.join(map(str,server_addr)) |
|
|
|
ws_send = self.ws.send(wsevents.Request(host=ws_host, target="/cmsocket/")) |
|
|
|
self.socket.sendall(ws_send) |
|
|
|
return |
|
|
|
except socket.error: |
|
|
|
if self.socket is not None: |
|
|
|
self.socket.close() |
|
|
|
|
|
|
|
def _writer_loop(self): |
|
|
|
while True: |
|
|
|
message = self.send_queue.get() |
|
|
|
try: |
|
|
|
logger.debug("sending message of length {}".format(len(message))) |
|
|
|
self.socket.sendall(self.ws.send(wsevents.Message(data=message))) |
|
|
|
except: |
|
|
|
logger.debug("Connection error (writer).") |
|
|
|
self.disconnect() |
|
|
|
return |
|
|
|
|
|
|
|
def _reader_loop(self): |
|
|
|
while True: |
|
|
|
rlist, _, _ = gselect([self.socket], [], []) |
|
|
|
|
|
|
|
def _write_data(self, data): |
|
|
|
self.socket.sendall(data) |
|
|
|
if self.socket in rlist: |
|
|
|
|
|
|
|
try: |
|
|
|
data = self.socket.recv(16384) |
|
|
|
except socket.error: |
|
|
|
data = '' |
|
|
|
|
|
|
|
if not data: |
|
|
|
logger.debug("Connection error (reader).") |
|
|
|
# A receive of zero bytes indicates the TCP socket has been closed. We |
|
|
|
# need to pass None to wsproto to update its internal state. |
|
|
|
logger.debug("Received 0 bytes (connection closed)") |
|
|
|
self.ws.receive_data(None) |
|
|
|
# now disconnect |
|
|
|
self.disconnect() |
|
|
|
return |
|
|
|
|
|
|
|
logger.debug("Received {} bytes".format(len(data))) |
|
|
|
self.ws.receive_data(data) |
|
|
|
self._handle_events() |
|
|
|
|
|
|
|
def _handle_events(self): |
|
|
|
for event in self.ws.events(): |
|
|
|
if isinstance(event, wsevents.AcceptConnection): |
|
|
|
logger.debug("WebSocket negotiation complete. Connected.") |
|
|
|
self.event_connected.set() |
|
|
|
elif isinstance(event, wsevents.RejectConnection): |
|
|
|
logger.debug("WebSocket connection was rejected. That's probably not good.") |
|
|
|
elif isinstance(event, wsevents.TextMessage): |
|
|
|
logger.debug("Received websocket text message of length: {}".format(len(event.data))) |
|
|
|
elif isinstance(event, wsevents.BytesMessage): |
|
|
|
logger.debug("Received websocket bytes message of length: {}".format(len(event.data))) |
|
|
|
self.recv_queue.put(event.data) |
|
|
|
elif isinstance(event, wsevents.Pong): |
|
|
|
logger.debug("Received pong: {}".format(repr(event.payload))) |
|
|
|
elif isinstance(event, wsevents.CloseConnection): |
|
|
|
logger.debug('Connection closed: code={} reason={}'.format( |
|
|
|
event.code, event.reason |
|
|
|
)) |
|
|
|
if self.ws.state == ConnectionState.REMOTE_CLOSING: |
|
|
|
self.socket.send(self.ws.send(event.response())) |
|
|
|
self.event_wsdisconnected.set() |
|
|
|
else: |
|
|
|
raise TypeError("Do not know how to handle event: {}".format((event))) |
|
|
|
|
|
|
|
def disconnect(self): |
|
|
|
self.event_wsdisconnected.clear() |
|
|
|
|
|
|
|
# WebSocket closing handshake |
|
|
|
if self.ws.state == ConnectionState.OPEN: |
|
|
|
logger.debug("Disconnect called. Sending CloseConnection message.") |
|
|
|
self.socket.sendall(self.ws.send(wsevents.CloseConnection(code=1000, reason="sample reason"))) |
|
|
|
self.socket.shutdown(socket.SHUT_WR) |
|
|
|
# wait for notification from _reader_loop that the closing response was received |
|
|
|
self.event_wsdisconnected.wait() |
|
|
|
|
|
|
|
super(WebsocketConnection, self).disconnect() |
|
|
|
|
|
|
|
class UDPConnection(Connection): |
|
|
|
def _new_socket(self): |
|
|
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
|
|
|
|
|
|
|
def _connect(self, server_addr): |
|
|
|
pass |
|
|
|
|
|
|
|
def _read_data(self): |
|
|
|
pass |
|
|
|
|
|
|
|
def _write_data(self, data): |
|
|
|
pass |
|
|
|