You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

156 lines
5.1 KiB

import asyncio
import logging
from asyncio.streams import StreamReader, StreamWriter
from time import time
import socket
from logger import logger
# Two-byte marker that starts every frame on the wire.
# NOTE(review): presumably the Meshtastic stream header (0x94 0xC3, TCP
# port 4403) -- confirm against the peer's protocol documentation.
MAGIC0 = 0x94
MAGIC1 = 0xC3
# Maximum number of bytes requested from the socket per read() call.
READ_CHUNK = 4096


def encode_frame(payload):
    """Wrap *payload* in a wire frame.

    Frame layout: ``MAGIC0``, ``MAGIC1``, payload length as a big-endian
    16-bit unsigned integer, then the payload bytes.

    Args:
        payload: bytes-like frame body.

    Returns:
        The complete frame as ``bytes``.

    Raises:
        ValueError: if *payload* is longer than 65535 bytes, which the
            16-bit length field cannot represent.
    """
    n = len(payload)
    if n > 0xFFFF:
        # Masking the length (as a bare ``& 0xFF`` would) produces a header
        # that disagrees with the body and desynchronizes the receiver.
        raise ValueError(f"payload too large for frame: {n} bytes (max 65535)")
    header = bytes((MAGIC0, MAGIC1, (n >> 8) & 0xFF, n & 0xFF))
    return header + payload
class TCPTransport:
    """Asyncio TCP client that exchanges length-prefixed frames.

    Outgoing payloads are wrapped with ``encode_frame``. Incoming bytes are
    reassembled into payloads by a background reader task and handed out
    through :meth:`recv`. A second background task periodically sends a
    small keep-alive frame.

    Failures are reported to consumers by putting a ``None`` sentinel on the
    receive queue; :meth:`recv` turns it into :class:`ConnectionError`.
    """

    def __init__(self, host, port=4403, alive_pool_connect=60):
        self.host = host
        self.port = int(port)
        self.reader: StreamReader = None  # set by start(), cleared by close()
        self.writer: StreamWriter = None  # set by start(), cleared by close()
        self._recv_q = asyncio.Queue()  # decoded payloads, or None on failure
        self._buf = bytearray()  # received bytes not yet split into frames
        self._reader_task = None
        self._alive_task = None
        self._error = None  # last exception recorded by the transport
        self._closing = False  # True once close() has started
        # Set at construction and refreshed after each successful connect.
        self.socket_start = time()
        self.alive_pool_connect = alive_pool_connect  # seconds between keep-alives

    async def start(self):
        """Open the TCP connection and spawn the reader and keep-alive tasks.

        Resets the error state, receive buffer and receive queue. Tasks left
        over from an earlier start() that was not followed by close() are
        cancelled first, so they cannot feed stale data into the new state.

        Raises:
            Exception: whatever ``asyncio.open_connection`` raises; it is
                also stored as the transport error.
        """
        self._cancel_stale_tasks()
        self._closing = False
        self._error = None
        self._buf = bytearray()
        self._recv_q = asyncio.Queue()
        try:
            reader, writer = await asyncio.open_connection(host=self.host, port=self.port)
            sock = writer.transport.get_extra_info('socket')
            if sock:
                # Enable OS-level TCP keep-alive probes on the idle connection.
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket_start = time()
        except Exception as e:
            self._error = e
            raise
        self.reader = reader
        self.writer = writer
        self._reader_task = asyncio.create_task(self._reader_loop(), name="tcp-reader")
        self._alive_task = asyncio.create_task(self._alive_loop(), name="alive-loop")
        logger.debug("TCPTransport.start: connected to %s:%s", self.host, self.port)

    def _cancel_stale_tasks(self):
        """Cancel (without awaiting) background tasks from a previous start()."""
        current = asyncio.current_task()
        for task in (self._reader_task, self._alive_task):
            if task is not None and task is not current:
                task.cancel()
        self._reader_task = None
        self._alive_task = None

    async def _alive_loop(self):
        """Send a keep-alive frame every ``alive_pool_connect`` seconds.

        send() reports write failures by recording ``self._error`` rather
        than raising. The loop therefore checks that attribute after each
        send, and before each new round; the latter also catches failures
        recorded by the reader. On failure, unless close() is already in
        progress, the connection is closed.
        """
        try:
            while not self._closing and self._error is None:
                await self.send(b"0")
                if self._error is None:
                    await asyncio.sleep(self.alive_pool_connect)
        except asyncio.CancelledError:
            return  # close() or start() stopped us
        except Exception as e:
            # Unexpected failure inside the loop itself; treat as a dead socket.
            if self._error is None:
                self._error = e
        if self._closing:
            return
        logger.error("Tcp socket is down, close connection")
        await self.close()

    def _fail(self, exc):
        """Record *exc* as the transport error and wake recv() with a sentinel."""
        self._error = exc
        # The queue is created unbounded, so put_nowait() cannot raise QueueFull.
        self._recv_q.put_nowait(None)

    def _extract_frames(self):
        """Remove every complete frame from the receive buffer.

        Bytes before the first ``MAGIC0 MAGIC1`` marker are discarded. An
        incomplete trailing frame stays buffered for the next read.

        Returns:
            List of payloads (``bytes``) in arrival order.
        """
        buf = self._buf
        magic = bytes((MAGIC0, MAGIC1))
        frames = []
        while True:
            start = buf.find(magic)
            if start == -1:
                # No marker at all: keep only the last byte, which may be the
                # first half of a marker split across two reads.
                if len(buf) > 1:
                    del buf[:-1]
                break
            if start > 0:
                del buf[:start]
            if len(buf) < 4:
                break  # header not complete yet
            total = 4 + ((buf[2] << 8) | buf[3])
            if len(buf) < total:
                break  # body not complete yet
            frames.append(bytes(buf[4:total]))
            del buf[:total]
        return frames

    async def _reader_loop(self):
        """Read from the socket, split the stream into frames, queue payloads.

        Runs until EOF, an I/O error, or cancellation. On EOF or error, unless
        close() is in progress, the error is recorded and a ``None`` sentinel
        is queued so that recv() raises.
        """
        assert self.reader is not None
        r = self.reader
        try:
            while True:
                data = await r.read(READ_CHUNK)
                if not data:
                    if self._closing:
                        return
                    logger.error("tcp read error: connection closed")
                    self._fail(ConnectionError("tcp connection closed"))
                    return
                self._buf.extend(data)
                for payload in self._extract_frames():
                    self._recv_q.put_nowait(payload)
        except asyncio.CancelledError:
            pass  # stopped by close() or start(); nothing to report
        except Exception as e:
            if self._closing:
                return
            logger.error("tcp read error: %s", e)
            self._fail(e)

    async def send(self, payload):
        """Frame and write *payload*; silently does nothing when not connected.

        I/O errors are not raised. They are logged, recorded as the transport
        error, and signalled to recv() consumers (unless close() is in
        progress).

        Raises:
            TypeError: if *payload* is not ``bytes`` or ``bytearray``.
            ValueError: if *payload* is too long to frame (see encode_frame).
        """
        writer = self.writer
        if writer is None:
            return
        if not isinstance(payload, (bytes, bytearray)):
            raise TypeError("payload must be bytes")
        # Encode outside the I/O try: a framing error is the caller's bug,
        # not a dead socket.
        frame = encode_frame(bytes(payload))
        try:
            writer.write(frame)
            await writer.drain()
        except Exception as e:
            if self._closing:
                return
            logger.error("tcp write error: %s", e)
            self._fail(e)

    async def recv(self):
        """Return the next received payload.

        Raises:
            ConnectionError: once the transport has failed or been closed.
                The sentinel is re-queued, so every later call fails fast
                instead of blocking forever.
        """
        item = await self._recv_q.get()
        if item is None:
            self._recv_q.put_nowait(None)
            raise ConnectionError(self._error or "tcp transport error") from self._error
        return item

    @staticmethod
    async def _cancel_task(task):
        """Cancel *task* and wait for it to finish, discarding its outcome.

        ``gather(..., return_exceptions=True)`` absorbs the task's own
        CancelledError (raised when it was cancelled before it ever ran) while
        still letting a cancellation of the caller propagate. The current task
        is skipped because a task cannot await itself.
        """
        if task is None or task is asyncio.current_task():
            return
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)

    async def close(self):
        """Stop background tasks, close the socket, and wake pending recv() calls.

        Safe to call repeatedly and from the keep-alive task itself. Errors
        while closing the socket are ignored.
        """
        self._closing = True
        reader_task, self._reader_task = self._reader_task, None
        alive_task, self._alive_task = self._alive_task, None
        await self._cancel_task(reader_task)
        await self._cancel_task(alive_task)
        # Wake any consumer blocked in recv(); it will raise ConnectionError.
        self._recv_q.put_nowait(None)
        writer, self.writer = self.writer, None
        self.reader = None
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass  # best effort: the peer may already be gone