4 changed files with 173 additions and 11 deletions
@ -1,2 +1,3 @@ |
|||
testenv/ |
|||
__pycache__/ |
|||
docker-compose.yaml |
|||
@ -1,12 +1,12 @@ |
|||
class logger: |
|||
@staticmethod |
|||
def info(t): |
|||
print("[INFO]", t) |
|||
def info(*t): |
|||
print("[INFO]", " ".join([str(s) for s in t])) |
|||
|
|||
@staticmethod |
|||
def error(t): |
|||
print("[ERROR]", t) |
|||
def error(*t): |
|||
print("[ERROR]", " ".join([str(s) for s in t])) |
|||
|
|||
@staticmethod |
|||
def debug(t): |
|||
print("[DEBUG]", t) |
|||
def debug(*t): |
|||
print("[DEBUG]", " ".join([str(s) for s in t])) |
|||
@ -0,0 +1,156 @@ |
|||
import asyncio |
|||
import logging |
|||
from asyncio.streams import StreamReader, StreamWriter |
|||
from time import time |
|||
|
|||
from logger import logger |
|||
|
|||
MAGIC0 = 0x94 |
|||
MAGIC1 = 0xC3 |
|||
READ_CHUNK = 4096 |
|||
|
|||
|
|||
def encode_frame(payload): |
|||
n = len(payload) |
|||
header = bytes([MAGIC0, MAGIC1, (n >> 8) & 0xFF, n & 0xFF]) |
|||
return header + payload |
|||
|
|||
|
|||
class TCPTransport: |
|||
def __init__(self, host, port=4403): |
|||
self.host = host |
|||
self.port = int(port) |
|||
self.reader: StreamReader = None |
|||
self.writer = None |
|||
self._recv_q = asyncio.Queue() |
|||
self._buf = bytearray() |
|||
self._reader_task = None |
|||
self._error = None |
|||
self._closing = False |
|||
self.socket_start = time() |
|||
|
|||
async def start(self): |
|||
self._closing = False |
|||
self._error = None |
|||
self._buf = bytearray() |
|||
self._recv_q = asyncio.Queue() |
|||
try: |
|||
#sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|||
#sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) |
|||
reader, writer = await asyncio.open_connection(self.host, self.port) |
|||
self.socket_start = time() |
|||
except Exception as e: |
|||
self._error = e |
|||
raise |
|||
self.reader = reader |
|||
self.writer = writer |
|||
self._reader_task = asyncio.create_task(self._reader_loop(), name="tcp-reader") |
|||
self._alive_task = asyncio.create_task(self._alive_loop(), name="alive-loop") |
|||
logger.debug("TCPTransport.start: connected to %s:%s", self.host, self.port) |
|||
|
|||
async def _alive_loop(self): |
|||
try: |
|||
while True: |
|||
if self.writer: |
|||
#logger.debug("Send alive", int(time() - self.socket_start)) |
|||
self.writer.write(b"0") |
|||
await self.writer.drain() |
|||
await asyncio.sleep(1) |
|||
except asyncio.CancelledError: |
|||
pass |
|||
except: |
|||
logger.error("Tcp socket is down, close connection") |
|||
await self.close() |
|||
|
|||
async def _reader_loop(self): |
|||
assert self.reader is not None |
|||
r = self.reader |
|||
try: |
|||
while True: |
|||
data = await r.read(READ_CHUNK) |
|||
if not data: |
|||
if self._closing: |
|||
return |
|||
self._error = ConnectionError("tcp connection closed") |
|||
logger.error("tcp read error: connection closed") |
|||
try: |
|||
self._recv_q.put_nowait(None) |
|||
except Exception: |
|||
pass |
|||
return |
|||
self._buf.extend(data) |
|||
while True: |
|||
start = -1 |
|||
for i in range(len(self._buf) - 1): |
|||
if self._buf[i] == MAGIC0 and self._buf[i + 1] == MAGIC1: |
|||
start = i |
|||
break |
|||
if start == -1: |
|||
if len(self._buf) > 1: |
|||
self._buf[:] = self._buf[-1:] |
|||
break |
|||
if start > 0: |
|||
del self._buf[:start] |
|||
if len(self._buf) < 4: |
|||
break |
|||
length = (self._buf[2] << 8) | self._buf[3] |
|||
total = 4 + length |
|||
if len(self._buf) < total: |
|||
break |
|||
payload = bytes(self._buf[4:total]) |
|||
await self._recv_q.put(payload) |
|||
del self._buf[:total] |
|||
except asyncio.CancelledError: |
|||
pass |
|||
except Exception as e: |
|||
if self._closing: |
|||
return |
|||
self._error = e |
|||
logger.error("tcp read error: %s", e) |
|||
try: |
|||
self._recv_q.put_nowait(None) |
|||
except Exception: |
|||
pass |
|||
|
|||
async def send(self, payload): |
|||
if self.writer is None: |
|||
return |
|||
if not isinstance(payload, (bytes, bytearray)): |
|||
raise TypeError("payload must be bytes") |
|||
try: |
|||
frame = encode_frame(bytes(payload)) |
|||
self.writer.write(frame) |
|||
await self.writer.drain() |
|||
except Exception as e: |
|||
if self._closing: |
|||
return |
|||
self._error = e |
|||
logger.error("tcp write error: %s", e) |
|||
try: |
|||
self._recv_q.put_nowait(None) |
|||
except Exception: |
|||
pass |
|||
|
|||
async def recv(self): |
|||
item = await self._recv_q.get() |
|||
if item is None: |
|||
raise ConnectionError(self._error or "tcp transport error") |
|||
return item |
|||
|
|||
async def close(self): |
|||
self._closing = True |
|||
if self._reader_task is not None: |
|||
self._reader_task.cancel() |
|||
try: |
|||
await self._reader_task |
|||
except Exception: |
|||
pass |
|||
self._reader_task = None |
|||
if self.writer is not None: |
|||
try: |
|||
self.writer.close() |
|||
await self.writer.wait_closed() |
|||
except Exception: |
|||
pass |
|||
self.writer = None |
|||
self.reader = None |
|||
Loading…
Reference in new issue