You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
5.1 KiB
156 lines
5.1 KiB
import asyncio
|
|
import logging
|
|
from asyncio.streams import StreamReader, StreamWriter
|
|
from time import time
|
|
import socket
|
|
|
|
from logger import logger
|
|
|
|
# Every frame starts with this two-byte sync marker (0x94 0xC3).
MAGIC0, MAGIC1 = 0x94, 0xC3
# Maximum number of bytes requested from the stream per read() call.
READ_CHUNK = 0x1000
|
|
|
|
|
|
def encode_frame(payload):
    """Prefix *payload* with the 4-byte frame header.

    The header is ``MAGIC0``, ``MAGIC1`` and then the payload length as a
    big-endian unsigned 16-bit integer (high byte first).

    Args:
        payload: The frame body (bytes-like).

    Returns:
        The header followed by *payload*.

    Raises:
        ValueError: if *payload* is longer than 65535 bytes. The 16-bit
            length field cannot represent that, and a truncated length would
            make the receiving side mis-split the stream.
    """
    n = len(payload)
    if n > 0xFFFF:
        raise ValueError(f"payload too large for one frame: {n} bytes (max 65535)")
    header = bytes((MAGIC0, MAGIC1, n >> 8, n & 0xFF))
    return header + payload
|
|
|
|
|
|
class TCPTransport:
    """Framed, asyncio-based TCP transport.

    Every message on the wire is a 4-byte header (``MAGIC0``, ``MAGIC1``
    and a big-endian 16-bit payload length) followed by the payload.
    A background reader task splits the incoming byte stream into payloads
    and queues them for :meth:`recv`. A background keep-alive task
    periodically sends a small frame.

    Failures are reported in-band. The I/O error is stored in ``_error``,
    and a ``None`` sentinel is queued, which makes :meth:`recv` raise
    ``ConnectionError``.
    """

    def __init__(self, host, port=4403, alive_pool_connect=60):
        """Store connection parameters; no I/O happens until :meth:`start`.

        Args:
            host: Remote host passed to ``asyncio.open_connection``.
            port: Remote TCP port; coerced with ``int()``.
            alive_pool_connect: Seconds to sleep between keep-alive frames.
        """
        self.host = host
        self.port = int(port)
        self.reader: StreamReader = None
        self.writer = None
        self._recv_q = asyncio.Queue()
        self._buf = bytearray()
        self._reader_task = None
        # Created by start(); initialised here so close() works before start().
        self._alive_task = None
        self._error = None
        self._closing = False
        self.socket_start = time()
        self.alive_pool_connect = alive_pool_connect

    async def start(self):
        """Open the connection and launch the reader and keep-alive tasks.

        Resets the error state, the receive buffer and the receive queue.
        NOTE: this does not tear down a previous connection or its tasks;
        call :meth:`close` first when reconnecting.

        Raises:
            Exception: whatever ``asyncio.open_connection`` or enabling
                ``SO_KEEPALIVE`` raised. It is also stored in ``_error``.
        """
        self._closing = False
        self._error = None
        self._buf = bytearray()
        self._recv_q = asyncio.Queue()
        try:
            reader, writer = await asyncio.open_connection(host=self.host, port=self.port)
        except Exception as e:
            self._error = e
            raise
        try:
            sock = writer.transport.get_extra_info('socket')
            if sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket_start = time()
        except Exception as e:
            self._error = e
            # The connection is already open; don't leak it on the error path.
            writer.close()
            raise
        self.reader = reader
        self.writer = writer
        self._reader_task = asyncio.create_task(self._reader_loop(), name="tcp-reader")
        self._alive_task = asyncio.create_task(self._alive_loop(), name="alive-loop")
        logger.debug("TCPTransport.start: connected to %s:%s", self.host, self.port)

    async def _alive_loop(self):
        """Send a keep-alive frame every ``alive_pool_connect`` seconds.

        ``send`` does not raise on I/O failure; it records the error in
        ``_error`` instead. The loop therefore checks ``_error`` after each
        send and closes the transport once a failure has been recorded.
        """
        try:
            while not self._closing:
                # NOTE(review): the keep-alive payload is the single ASCII
                # byte b"0", framed like any other message; confirm the peer
                # tolerates it.
                await self.send(b"0")
                if self._error is not None:
                    break
                await asyncio.sleep(self.alive_pool_connect)
        except asyncio.CancelledError:
            return
        except Exception:
            # Anything unexpected (for example an invalid alive_pool_connect
            # rejected by asyncio.sleep) is treated like a dead socket.
            pass
        if self._closing:
            return
        logger.error("Tcp socket is down, close connection")
        await self.close()

    async def _reader_loop(self):
        """Read from the socket, split frames and queue their payloads.

        On EOF or a read error, the error is stored in ``_error`` and the
        ``None`` sentinel is queued. This does not happen when the transport
        is closing deliberately.
        """
        assert self.reader is not None
        r = self.reader
        try:
            while True:
                data = await r.read(READ_CHUNK)
                if not data:
                    if self._closing:
                        return
                    self._error = ConnectionError("tcp connection closed")
                    logger.error("tcp read error: connection closed")
                    self._signal_error()
                    return
                self._buf.extend(data)
                for payload in self._pop_frames():
                    await self._recv_q.put(payload)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            if self._closing:
                return
            self._error = e
            logger.error("tcp read error: %s", e)
            self._signal_error()

    def _pop_frames(self):
        """Remove every complete frame from ``_buf`` and return the payloads.

        Bytes before the first sync marker are discarded. When no marker is
        present, only the last byte is kept, because it may be the first half
        of a marker split across two reads. An incomplete trailing frame is
        left in the buffer for the next read.
        """
        sync = bytes((MAGIC0, MAGIC1))
        buf = self._buf  # mutated in place
        payloads = []
        while True:
            start = buf.find(sync)
            if start == -1:
                if len(buf) > 1:
                    del buf[:-1]
                break
            if start:
                del buf[:start]
            if len(buf) < 4:
                break
            total = 4 + ((buf[2] << 8) | buf[3])
            if len(buf) < total:
                break
            payloads.append(bytes(buf[4:total]))
            del buf[:total]
        return payloads

    def _signal_error(self):
        """Queue the ``None`` sentinel that makes :meth:`recv` raise."""
        try:
            self._recv_q.put_nowait(None)
        except asyncio.QueueFull:
            pass

    async def send(self, payload):
        """Frame *payload* and write it to the socket.

        Does nothing if the transport is not connected. Write errors are not
        raised to the caller. Unless the transport is closing, they are
        logged, stored in ``_error`` and signalled to :meth:`recv` via the
        ``None`` sentinel.

        Raises:
            TypeError: if *payload* is not ``bytes``, ``bytearray`` or
                ``memoryview``.
        """
        if self.writer is None:
            return
        if not isinstance(payload, (bytes, bytearray, memoryview)):
            raise TypeError("payload must be bytes")
        # Framed outside the try so that an error raised while framing reaches
        # the caller instead of being reported as a broken connection.
        frame = encode_frame(bytes(payload))
        try:
            self.writer.write(frame)
            await self.writer.drain()
        except Exception as e:
            if self._closing:
                return
            self._error = e
            logger.error("tcp write error: %s", e)
            self._signal_error()

    async def recv(self):
        """Wait for and return the next received payload.

        NOTE: a deliberate :meth:`close` queues no sentinel, so a ``recv``
        that is already waiting at that point keeps waiting.

        Raises:
            ConnectionError: when the error sentinel is dequeued. The stored
                error, if any, is attached as ``__cause__``.
        """
        item = await self._recv_q.get()
        if item is None:
            raise ConnectionError(self._error or "tcp transport error") from self._error
        return item

    async def close(self):
        """Stop the background tasks and close the connection.

        Safe to call more than once, before :meth:`start`, and from the
        keep-alive task itself. That task is not cancelled from within
        itself. Errors raised while closing the socket are ignored on a
        best-effort basis.
        """
        self._closing = True
        current = asyncio.current_task()
        tasks = [
            task
            for task in (self._reader_task, self._alive_task)
            if task is not None and task is not current
        ]
        for task in tasks:
            task.cancel()
        self._reader_task = None
        self._alive_task = None
        # return_exceptions=True keeps the tasks' own errors from escaping.
        # That includes the CancelledError of a task cancelled before it ever
        # ran. Cancellation of close() itself still propagates.
        await asyncio.gather(*tasks, return_exceptions=True)
        if self.writer is not None:
            try:
                self.writer.close()
                await self.writer.wait_closed()
            except Exception:
                pass  # best effort: the connection is being discarded anyway
            self.writer = None
        self.reader = None
|