pythonhacktoberfeststeamauthenticationauthenticatorsteam-authenticatorsteam-clientsteam-guard-codessteam-websteamworksvalvewebapi
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
278 lines
9.0 KiB
278 lines
9.0 KiB
import struct
|
|
import logging
|
|
|
|
import gevent
|
|
from gevent import socket
|
|
from gevent import queue
|
|
from gevent import event
|
|
from gevent.select import select as gselect
|
|
|
|
import ssl
|
|
import certifi
|
|
|
|
from wsproto import WSConnection, events as wsevents
|
|
from wsproto.connection import ConnectionType, ConnectionState
|
|
|
|
# Module-wide logger shared by every connection class below.
logger = logging.getLogger(name="Connection")
|
|
|
|
|
|
class Connection(object):
    """Base class for a message-oriented connection driven by gevent greenlets.

    Subclasses supply the transport by implementing ``_new_socket``,
    ``_connect``, ``_reader_loop`` and ``_writer_loop``; the stubs here raise
    ``TypeError``. Outgoing messages are queued with :meth:`put_message` and
    are consumed by the writer greenlet. Incoming messages are expected to be
    put on ``recv_queue`` by the reader greenlet and can be consumed by
    iterating the connection. :meth:`disconnect` puts ``StopIteration`` on
    ``recv_queue`` so that such iteration ends.
    """

    def __init__(self):
        self.socket = None
        self.connected = False
        self.server_addr = None

        self._reader = None  # reader greenlet while connected
        self._writer = None  # writer greenlet while connected
        self._readbuf = b''
        self.send_queue = queue.Queue()
        self.recv_queue = queue.Queue()

        # Set by the subclass once the transport is ready for use.
        self.event_connected = event.Event()

    @property
    def local_address(self):
        """Return the first element of the socket's local address.

        For IP sockets this is the local host address.
        """
        return self.socket.getsockname()[0]

    def connect(self, server_addr, timeout=10):
        """Connect to ``server_addr`` and start the reader/writer greenlets.

        :param server_addr: address handed to the subclass's ``_connect``
            (the subclasses in this module expect a ``(host, port)`` tuple)
        :param timeout: seconds to wait for ``event_connected`` to be set
        :return: ``True`` once ``event_connected`` is set; ``False`` if
            ``_connect`` raised ``socket.error`` or the connection was not
            confirmed within ``timeout`` seconds
        """
        self._new_socket()

        logger.debug("Attempting connection to %s", str(server_addr))

        try:
            self._connect(server_addr)
        except socket.error:
            # Release the socket created by _new_socket() instead of leaking it.
            self._close_socket()
            return False

        self.server_addr = server_addr
        self.recv_queue.queue.clear()

        self._reader = gevent.spawn(self._reader_loop)
        self._writer = gevent.spawn(self._writer_loop)

        # How event_connected gets set is implementation dependent. If it is
        # not set in time the connection is unusable; clean up here, because
        # disconnect() ignores connections that never reached that point.
        if not self.event_connected.wait(timeout=timeout):
            logger.debug("Timed out waiting for connection to %s", str(server_addr))
            self.server_addr = None
            self._stop_workers()
            self._close_socket()
            return False

        return True

    def disconnect(self):
        """Tear down an established connection.

        Does nothing unless ``event_connected`` is set. Otherwise it stops the
        reader/writer greenlets and drops all pending messages in both queues.
        It then puts ``StopIteration`` on ``recv_queue`` to end iteration and
        closes the socket.
        """
        if not self.event_connected.is_set():
            return
        self.event_connected.clear()

        self.server_addr = None
        self._stop_workers()

        self._readbuf = b''
        self.send_queue.queue.clear()
        self.recv_queue.queue.clear()
        self.recv_queue.put(StopIteration)

        self._close_socket()

        logger.debug("Disconnected.")

    def _stop_workers(self):
        """Kill the reader and writer greenlets without waiting for them."""
        if self._reader:
            self._reader.kill(block=False)
            self._reader = None
        if self._writer:
            self._writer.kill(block=False)
            self._writer = None

    def _close_socket(self):
        """Close ``self.socket`` if one has been created."""
        if self.socket is not None:
            self.socket.close()

    def __iter__(self):
        """Iterate over received messages.

        The gevent queue stops iterating when it dequeues ``StopIteration``,
        which ``disconnect()`` enqueues.
        """
        return self.recv_queue

    def put_message(self, message):
        """Queue ``message`` for the writer greenlet to send."""
        self.send_queue.put(message)

    def _new_socket(self):
        raise TypeError("{}: _new_socket is unimplemented".format(self.__class__.__name__))

    def _connect(self, server_addr):
        raise TypeError("{}: _connect is unimplemented".format(self.__class__.__name__))

    def _reader_loop(self):
        raise TypeError("{}: _reader_loop is unimplemented".format(self.__class__.__name__))

    def _writer_loop(self):
        raise TypeError("{}: _writer_loop is unimplemented".format(self.__class__.__name__))
|
|
|
|
class TCPConnection(Connection):
    """Plain TCP transport with length-prefixed framing.

    Each message is sent as a 4-byte little-endian unsigned length, the
    4-byte magic ``b'VT01'``, then the message bytes (``FMT = '<I4s'``).
    """

    MAGIC = b'VT01'
    FMT = '<I4s'
    FMT_SIZE = struct.calcsize(FMT)

    def _new_socket(self):
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def _connect(self, server_addr):
        self.socket.connect(server_addr)
        logger.debug("Connected.")
        # Plain TCP has no handshake: the connection is usable immediately.
        self.event_connected.set()

    def _read_data(self):
        """Read up to 16384 bytes; return ``b''`` if the socket errors."""
        try:
            return self.socket.recv(16384)
        except socket.error:
            return b''

    def _write_data(self, data):
        self.socket.sendall(data)

    def _writer_loop(self):
        """Frame and send queued messages until a write fails."""
        while True:
            message = self.send_queue.get()
            packet = struct.pack(TCPConnection.FMT, len(message), TCPConnection.MAGIC) + message

            try:
                self._write_data(packet)
            except Exception:
                # Not a bare ``except``: the GreenletExit that disconnect()
                # uses to kill this greenlet must not be treated as an error.
                logger.debug("Connection error (writer).")
                self.disconnect()
                return

    def _reader_loop(self):
        """Accumulate received bytes and split them into packets."""
        while True:
            rlist, _, _ = gselect([self.socket], [], [])

            if self.socket in rlist:
                data = self._read_data()

                if not data:
                    logger.debug("Connection error (reader).")
                    self.disconnect()
                    return

                self._readbuf += data
                self._read_packets()

    def _read_packets(self):
        """Move every complete packet in ``_readbuf`` onto ``recv_queue``.

        A trailing incomplete packet is kept in ``_readbuf`` until more data
        arrives. A header with the wrong magic causes a disconnect.
        """
        header_size = TCPConnection.FMT_SIZE
        buf = self._readbuf
        offset = 0

        # ``>=``: a header alone is a complete packet when its length is 0.
        while len(buf) - offset >= header_size:
            message_length, magic = struct.unpack_from(TCPConnection.FMT, buf, offset)

            if magic != TCPConnection.MAGIC:
                logger.debug("invalid magic, got %r", magic)
                self.disconnect()
                return

            packet_end = offset + header_size + message_length

            if len(buf) < packet_end:
                # Incomplete packet: wait for more data. ``break`` rather than
                # ``return`` so the packets consumed above are removed from
                # the buffer and are not delivered a second time.
                break

            self.recv_queue.put(buf[offset + header_size:packet_end])
            offset = packet_end

        self._readbuf = buf[offset:]
|
|
|
|
class WebsocketConnection(Connection):
    """Transport that exchanges binary WebSocket messages over TLS.

    The handshake requests target ``/cmsocket/``, and the TLS context verifies
    certificates against the certifi CA bundle.

    - Each binary WebSocket message is put on ``recv_queue``.
    - ``event_connected`` is set when the server accepts the handshake.
    - ``event_wsdisconnected`` is set when a WebSocket close is processed or
      the transport is lost.
    """

    # Seconds disconnect() waits for the server's reply to our Close frame.
    CLOSE_TIMEOUT = 10

    def __init__(self):
        super(WebsocketConnection, self).__init__()
        self.ws = WSConnection(ConnectionType.CLIENT)
        self.ssl_ctx = ssl.create_default_context(cafile=certifi.where())
        self.event_wsdisconnected = event.Event()
        self.raw_socket = None

    def _new_socket(self):
        # The TCP socket is created per resolved address in _connect(),
        # because getaddrinfo() may return addresses of different families.
        self.raw_socket = None

    def _connect(self, server_addr):
        """Open a TLS connection to ``server_addr`` and send the WS handshake.

        Each address from ``getaddrinfo`` is tried with a fresh socket of the
        matching family. On success, ``raw_socket`` and ``socket`` (the
        TLS-wrapped socket) are set. The handshake response is processed later
        by ``_reader_loop``.

        :raises socket.error: if no address could be connected to
        """
        host, port = server_addr
        ws_host = ':'.join(map(str, server_addr))
        last_error = None

        for family, socktype, proto, _, sockaddr in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
            raw_sock = socket.socket(family, socktype, proto)
            tls_sock = None
            try:
                # tcp socket
                raw_sock.connect(sockaddr)
                tls_sock = self.ssl_ctx.wrap_socket(raw_sock, server_hostname=host)
                # websocket: a wsproto connection cannot be reused once it has
                # handshaken or closed, so every attempt (and reconnect)
                # starts with a fresh one.
                self.ws = WSConnection(ConnectionType.CLIENT)
                tls_sock.sendall(self.ws.send(wsevents.Request(host=ws_host, target="/cmsocket/")))
            except socket.error as exc:
                logger.debug("Connection attempt to %s failed: %s", str(sockaddr), exc)
                # Close this attempt's sockets (not a stale self.socket).
                if tls_sock is not None:
                    tls_sock.close()
                raw_sock.close()
                last_error = exc
                continue

            self.raw_socket = raw_sock
            self.socket = tls_sock
            return

        # Report failure so Connection.connect() returns False instead of
        # starting the I/O loops on a socket that never connected.
        if last_error is None:
            last_error = socket.error("no addresses found for %s" % (server_addr,))
        raise last_error

    def _writer_loop(self):
        """Send queued messages as WebSocket messages until a send fails."""
        while True:
            message = self.send_queue.get()
            try:
                logger.debug("sending message of length %d", len(message))
                self.socket.sendall(self.ws.send(wsevents.Message(data=message)))
            except Exception:
                # Not a bare ``except``: the GreenletExit that disconnect()
                # uses to kill this greenlet must not be treated as an error.
                logger.debug("Connection error (writer).")
                self.disconnect()
                return

    def _reader_loop(self):
        """Feed received bytes to wsproto and dispatch the resulting events."""
        while True:
            rlist, _, _ = gselect([self.socket], [], [])

            if self.socket not in rlist:
                continue

            try:
                data = self.socket.recv(16384)
            except socket.error:
                data = b''

            if not data:
                logger.debug("Connection error (reader).")
                # A receive of zero bytes indicates the TCP socket has been
                # closed. Pass None to wsproto to update its internal state.
                logger.debug("Received 0 bytes (connection closed)")
                self.ws.receive_data(None)
                # Wake a disconnect() that may be waiting for a Close reply
                # that can no longer arrive.
                self.event_wsdisconnected.set()
                # now disconnect
                self.disconnect()
                return

            logger.debug("Received %d bytes", len(data))
            self.ws.receive_data(data)
            self._handle_events()

    def _handle_events(self):
        """Dispatch every pending wsproto event.

        :raises TypeError: for an event type this class does not handle
        """
        for ws_event in self.ws.events():
            if isinstance(ws_event, wsevents.AcceptConnection):
                logger.debug("WebSocket negotiation complete. Connected.")
                self.event_connected.set()
            elif isinstance(ws_event, wsevents.RejectConnection):
                logger.debug("WebSocket connection was rejected. That's probably not good.")
            elif isinstance(ws_event, wsevents.RejectData):
                # Body of a rejected handshake response; nothing to do with it.
                logger.debug("Received rejection data of length: %d", len(ws_event.data))
            elif isinstance(ws_event, wsevents.TextMessage):
                logger.debug("Received websocket text message of length: %d", len(ws_event.data))
            elif isinstance(ws_event, wsevents.BytesMessage):
                logger.debug("Received websocket bytes message of length: %d", len(ws_event.data))
                self.recv_queue.put(ws_event.data)
            elif isinstance(ws_event, wsevents.Ping):
                # wsproto does not answer pings automatically. A Pong can only
                # be sent while the connection is OPEN.
                logger.debug("Received ping: %r", ws_event.payload)
                if self.ws.state == ConnectionState.OPEN:
                    self.socket.sendall(self.ws.send(ws_event.response()))
            elif isinstance(ws_event, wsevents.Pong):
                logger.debug("Received pong: %r", ws_event.payload)
            elif isinstance(ws_event, wsevents.CloseConnection):
                logger.debug('Connection closed: code=%s reason=%s', ws_event.code, ws_event.reason)
                if self.ws.state == ConnectionState.REMOTE_CLOSING:
                    # Server initiated the close: send our Close reply in full.
                    self.socket.sendall(self.ws.send(ws_event.response()))
                self.event_wsdisconnected.set()
            else:
                raise TypeError("Do not know how to handle event: {}".format((ws_event)))

    def disconnect(self):
        """Close the WebSocket cleanly when it is open, then tear down.

        While the WebSocket is OPEN, a Close frame (code 1000) is sent. The
        method then waits up to ``CLOSE_TIMEOUT`` seconds for ``_reader_loop``
        to signal ``event_wsdisconnected``: either the server's Close reply
        was processed or the transport was lost. The base class then stops
        the greenlets and closes the socket.
        """
        if self.ws.state == ConnectionState.OPEN:
            logger.debug("Disconnect called. Sending CloseConnection message.")
            # Cleared only here, so a call made by the reader after transport
            # loss cannot reset the flag that another waiter depends on.
            self.event_wsdisconnected.clear()
            try:
                self.socket.sendall(self.ws.send(wsevents.CloseConnection(code=1000, reason="sample reason")))
            except socket.error:
                logger.debug("Could not send CloseConnection; closing without handshake.")
            else:
                # No shutdown(SHUT_WR) here: on an SSLSocket it discards the
                # TLS layer, so the server's Close reply could not be decrypted.
                # Wait for notification from _reader_loop that the closing
                # response was received (bounded so this cannot hang forever).
                self.event_wsdisconnected.wait(timeout=self.CLOSE_TIMEOUT)

        super(WebsocketConnection, self).disconnect()
|
|
|
|
class UDPConnection(Connection):
    """UDP transport stub.

    Only socket creation is provided. ``_connect``, ``_reader_loop`` and
    ``_writer_loop`` are the inherited stubs that raise ``TypeError``.
    """

    def _new_socket(self):
        # IPv4 datagram socket.
        family, kind = socket.AF_INET, socket.SOCK_DGRAM
        self.socket = socket.socket(family, kind)
|
|
|