|
|
|
@ -15,9 +15,42 @@ pub(crate) const UDPGW_LENGTH_FIELD_SIZE: usize = std::mem::size_of::<u16>(); |
|
|
|
pub(crate) const UDPGW_MAX_CONNECTIONS: u16 = 100; |
|
|
|
pub(crate) const UDPGW_KEEPALIVE_TIME: tokio::time::Duration = std::time::Duration::from_secs(10); |
|
|
|
|
|
|
|
pub const UDPGW_FLAG_KEEPALIVE: u8 = 0x01; |
|
|
|
pub const UDPGW_FLAG_ERR: u8 = 0x20; |
|
|
|
pub const UDPGW_FLAG_DATA: u8 = 0x02; |
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] |
|
|
|
pub struct UdpFlag(pub u8); |
|
|
|
|
|
|
|
impl UdpFlag { |
|
|
|
pub const ZERO: UdpFlag = UdpFlag(0x00); |
|
|
|
pub const KEEPALIVE: UdpFlag = UdpFlag(0x01); |
|
|
|
pub const ERR: UdpFlag = UdpFlag(0x20); |
|
|
|
pub const DATA: UdpFlag = UdpFlag(0x02); |
|
|
|
} |
|
|
|
|
|
|
|
impl std::fmt::Display for UdpFlag { |
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
|
|
let flag = match self.0 { |
|
|
|
0x00 => "ZERO", |
|
|
|
0x01 => "KEEPALIVE", |
|
|
|
0x20 => "ERR", |
|
|
|
0x02 => "DATA", |
|
|
|
n => return write!(f, "Unknown UdpFlag(0x{:02X})", n), |
|
|
|
}; |
|
|
|
write!(f, "UdpFlag({})", flag) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
impl std::ops::BitAnd for UdpFlag { |
|
|
|
type Output = Self; |
|
|
|
fn bitand(self, rhs: Self) -> Self::Output { |
|
|
|
UdpFlag(self.0 & rhs.0) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
impl std::ops::BitOr for UdpFlag { |
|
|
|
type Output = Self; |
|
|
|
fn bitor(self, rhs: Self) -> Self::Output { |
|
|
|
UdpFlag(self.0 | rhs.0) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static TCP_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); |
|
|
|
|
|
|
|
@ -98,7 +131,7 @@ impl TryFrom<&[u8]> for Packet { |
|
|
|
return Err(std::io::ErrorKind::InvalidData.into()); |
|
|
|
} |
|
|
|
let header = UdpgwHeader::retrieve_from_stream(&mut iter)?; |
|
|
|
let address = if header.flags & UDPGW_FLAG_DATA != 0 { |
|
|
|
let address = if header.flags & UdpFlag::DATA != UdpFlag::ZERO { |
|
|
|
Some(Address::retrieve_from_stream(&mut iter)?) |
|
|
|
} else { |
|
|
|
None |
|
|
|
@ -114,11 +147,11 @@ impl Packet { |
|
|
|
} |
|
|
|
|
|
|
|
pub fn build_keepalive_packet(conn_id: u16) -> Self { |
|
|
|
Packet::new(UdpgwHeader::new(UDPGW_FLAG_KEEPALIVE, conn_id), None, &[]) |
|
|
|
Packet::new(UdpgwHeader::new(UdpFlag::KEEPALIVE, conn_id), None, &[]) |
|
|
|
} |
|
|
|
|
|
|
|
pub fn build_error_packet(conn_id: u16) -> Self { |
|
|
|
Packet::new(UdpgwHeader::new(UDPGW_FLAG_ERR, conn_id), None, &[]) |
|
|
|
Packet::new(UdpgwHeader::new(UdpFlag::ERR, conn_id), None, &[]) |
|
|
|
} |
|
|
|
|
|
|
|
pub fn build_packet_from_address(conn_id: u16, remote_addr: &Address, data: &[u8]) -> std::io::Result<Self> { |
|
|
|
@ -132,7 +165,7 @@ impl Packet { |
|
|
|
|
|
|
|
pub fn build_ip_packet(conn_id: u16, remote_addr: SocketAddr, data: &[u8]) -> Self { |
|
|
|
let addr: Address = remote_addr.into(); |
|
|
|
Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data) |
|
|
|
Packet::new(UdpgwHeader::new(UdpFlag::DATA, conn_id), Some(addr), data) |
|
|
|
} |
|
|
|
|
|
|
|
pub fn build_domain_packet(conn_id: u16, port: u16, domain: &str, data: &[u8]) -> std::io::Result<Self> { |
|
|
|
@ -140,7 +173,7 @@ impl Packet { |
|
|
|
return Err(std::io::ErrorKind::InvalidInput.into()); |
|
|
|
} |
|
|
|
let addr = Address::from((domain, port)); |
|
|
|
Ok(Packet::new(UdpgwHeader::new(UDPGW_FLAG_DATA, conn_id), Some(addr), data)) |
|
|
|
Ok(Packet::new(UdpgwHeader::new(UdpFlag::DATA, conn_id), Some(addr), data)) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@ -154,7 +187,7 @@ impl StreamOperation for Packet { |
|
|
|
stream.read_exact(&mut buf)?; |
|
|
|
let length = u16::from_be_bytes(buf) as usize; |
|
|
|
let header = UdpgwHeader::retrieve_from_stream(stream)?; |
|
|
|
let address = if header.flags & UDPGW_FLAG_DATA != 0 { |
|
|
|
let address = if header.flags & UdpFlag::DATA == UdpFlag::DATA { |
|
|
|
Some(Address::retrieve_from_stream(stream)?) |
|
|
|
} else { |
|
|
|
None |
|
|
|
@ -194,7 +227,7 @@ impl AsyncStreamOperation for Packet { |
|
|
|
r.read_exact(&mut buf).await?; |
|
|
|
let length = u16::from_be_bytes(buf) as usize; |
|
|
|
let header = UdpgwHeader::retrieve_from_async_stream(r).await?; |
|
|
|
let address = if header.flags & UDPGW_FLAG_DATA != 0 { |
|
|
|
let address = if header.flags & UdpFlag::DATA == UdpFlag::DATA { |
|
|
|
Some(Address::retrieve_from_async_stream(r).await?) |
|
|
|
} else { |
|
|
|
None |
|
|
|
@ -211,14 +244,14 @@ impl AsyncStreamOperation for Packet { |
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] |
|
|
|
pub struct UdpgwHeader { |
|
|
|
pub flags: u8, |
|
|
|
pub flags: UdpFlag, |
|
|
|
pub conn_id: u16, |
|
|
|
} |
|
|
|
|
|
|
|
impl std::fmt::Display for UdpgwHeader { |
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
|
|
let id = self.conn_id; |
|
|
|
write!(f, "flags: 0x{:02x}, conn_id: {}", self.flags, id) |
|
|
|
write!(f, "flags: {}, conn_id: {}", self.flags, id) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@ -257,7 +290,7 @@ impl AsyncStreamOperation for UdpgwHeader { |
|
|
|
} |
|
|
|
|
|
|
|
impl UdpgwHeader { |
|
|
|
pub fn new(flags: u8, conn_id: u16) -> Self { |
|
|
|
pub fn new(flags: UdpFlag, conn_id: u16) -> Self { |
|
|
|
UdpgwHeader { flags, conn_id } |
|
|
|
} |
|
|
|
|
|
|
|
@ -274,14 +307,14 @@ impl TryFrom<&[u8]> for UdpgwHeader { |
|
|
|
return Err(std::io::ErrorKind::InvalidData.into()); |
|
|
|
} |
|
|
|
let conn_id = u16::from_be_bytes([value[1], value[2]]); |
|
|
|
Ok(UdpgwHeader { flags: value[0], conn_id }) |
|
|
|
Ok(UdpgwHeader::new(UdpFlag(value[0]), conn_id)) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
impl From<&UdpgwHeader> for Vec<u8> { |
|
|
|
fn from(header: &UdpgwHeader) -> Vec<u8> { |
|
|
|
let mut bytes = vec![0; header.len()]; |
|
|
|
bytes[0] = header.flags; |
|
|
|
bytes[0] = header.flags.0; |
|
|
|
bytes[1..3].copy_from_slice(&header.conn_id.to_be_bytes()); |
|
|
|
bytes |
|
|
|
} |
|
|
|
@ -296,14 +329,17 @@ pub(crate) enum UdpGwResponse { |
|
|
|
Data(Packet), |
|
|
|
} |
|
|
|
|
|
|
|
static SERIAL_NUMBER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1); |
|
|
|
|
|
|
|
#[derive(Debug)] |
|
|
|
pub(crate) struct UdpGwClientStream { |
|
|
|
local_addr: String, |
|
|
|
local_addr: SocketAddr, |
|
|
|
writer: Option<OwnedWriteHalf>, |
|
|
|
reader: Option<OwnedReadHalf>, |
|
|
|
conn_id: u16, |
|
|
|
closed: bool, |
|
|
|
last_activity: std::time::Instant, |
|
|
|
serial_number: u64, |
|
|
|
} |
|
|
|
|
|
|
|
impl Drop for UdpGwClientStream { |
|
|
|
@ -333,34 +369,33 @@ impl UdpGwClientStream { |
|
|
|
self.writer.take() |
|
|
|
} |
|
|
|
|
|
|
|
pub fn local_addr(&self) -> &String { |
|
|
|
&self.local_addr |
|
|
|
pub fn local_addr(&self) -> SocketAddr { |
|
|
|
self.local_addr |
|
|
|
} |
|
|
|
|
|
|
|
pub fn update_activity(&mut self) { |
|
|
|
self.last_activity = std::time::Instant::now(); |
|
|
|
} |
|
|
|
|
|
|
|
pub fn is_closed(&mut self) -> bool { |
|
|
|
pub fn is_closed(&self) -> bool { |
|
|
|
self.closed |
|
|
|
} |
|
|
|
|
|
|
|
pub fn id(&mut self) -> u16 { |
|
|
|
self.conn_id |
|
|
|
pub fn serial_number(&self) -> u64 { |
|
|
|
self.serial_number |
|
|
|
} |
|
|
|
|
|
|
|
pub fn new_id(&mut self) -> u16 { |
|
|
|
pub fn new_packet_id(&mut self) -> u16 { |
|
|
|
self.conn_id += 1; |
|
|
|
self.conn_id |
|
|
|
} |
|
|
|
|
|
|
|
pub fn new(tcp_server_stream: TcpStream) -> Self { |
|
|
|
let default = "0.0.0.0:0".parse::<SocketAddr>().unwrap(); |
|
|
|
let local_addr = tcp_server_stream.local_addr().unwrap_or(default).to_string(); |
|
|
|
let (rx, tx) = tcp_server_stream.into_split(); |
|
|
|
let writer = tx; |
|
|
|
let reader = rx; |
|
|
|
let local_addr = tcp_server_stream.local_addr().unwrap_or(default); |
|
|
|
let (reader, writer) = tcp_server_stream.into_split(); |
|
|
|
TCP_COUNTER.fetch_add(1, Relaxed); |
|
|
|
let serial_number = SERIAL_NUMBER.fetch_add(1, Relaxed); |
|
|
|
UdpGwClientStream { |
|
|
|
local_addr, |
|
|
|
reader: Some(reader), |
|
|
|
@ -368,6 +403,7 @@ impl UdpGwClientStream { |
|
|
|
last_activity: std::time::Instant::now(), |
|
|
|
closed: false, |
|
|
|
conn_id: 0, |
|
|
|
serial_number, |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@ -378,18 +414,18 @@ pub(crate) struct UdpGwClient { |
|
|
|
max_connections: u16, |
|
|
|
udp_timeout: u64, |
|
|
|
keepalive_time: Duration, |
|
|
|
server_addr: SocketAddr, |
|
|
|
udpgw_server: SocketAddr, |
|
|
|
server_connections: Mutex<VecDeque<UdpGwClientStream>>, |
|
|
|
} |
|
|
|
|
|
|
|
impl UdpGwClient { |
|
|
|
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, server_addr: SocketAddr) -> Self { |
|
|
|
pub fn new(udp_mtu: u16, max_connections: u16, keepalive_time: Duration, udp_timeout: u64, udpgw_server: SocketAddr) -> Self { |
|
|
|
let server_connections = Mutex::new(VecDeque::with_capacity(max_connections as usize)); |
|
|
|
UdpGwClient { |
|
|
|
udp_mtu, |
|
|
|
max_connections, |
|
|
|
udp_timeout, |
|
|
|
server_addr, |
|
|
|
udpgw_server, |
|
|
|
keepalive_time, |
|
|
|
server_connections, |
|
|
|
} |
|
|
|
@ -407,22 +443,17 @@ impl UdpGwClient { |
|
|
|
TCP_COUNTER.load(Relaxed) >= self.max_connections as u32 |
|
|
|
} |
|
|
|
|
|
|
|
pub(crate) async fn get_server_connection(&self) -> Option<UdpGwClientStream> { |
|
|
|
pub(crate) async fn pop_server_connection_from_queue(&self) -> Option<UdpGwClientStream> { |
|
|
|
self.server_connections.lock().await.pop_front() |
|
|
|
} |
|
|
|
|
|
|
|
pub(crate) async fn release_server_connection(&self, stream: UdpGwClientStream) { |
|
|
|
pub(crate) async fn store_server_connection(&self, stream: UdpGwClientStream) { |
|
|
|
if self.server_connections.lock().await.len() < self.max_connections as usize { |
|
|
|
self.server_connections.lock().await.push_back(stream); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
pub(crate) async fn release_server_connection_full( |
|
|
|
&self, |
|
|
|
mut stream: UdpGwClientStream, |
|
|
|
reader: OwnedReadHalf, |
|
|
|
writer: OwnedWriteHalf, |
|
|
|
) { |
|
|
|
pub(crate) async fn store_server_connection_full(&self, mut stream: UdpGwClientStream, reader: OwnedReadHalf, writer: OwnedWriteHalf) { |
|
|
|
if self.server_connections.lock().await.len() < self.max_connections as usize { |
|
|
|
stream.set_reader(Some(reader)); |
|
|
|
stream.set_writer(Some(writer)); |
|
|
|
@ -430,42 +461,51 @@ impl UdpGwClient { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
pub(crate) fn get_server_addr(&self) -> SocketAddr { |
|
|
|
self.server_addr |
|
|
|
pub(crate) fn get_udpgw_server_addr(&self) -> SocketAddr { |
|
|
|
self.udpgw_server |
|
|
|
} |
|
|
|
|
|
|
|
/// Heartbeat task asynchronous function to periodically check and maintain the active state of the server connection.
|
|
|
|
pub(crate) async fn heartbeat_task(&self) { |
|
|
|
pub(crate) async fn heartbeat_task(&self) -> std::io::Result<()> { |
|
|
|
loop { |
|
|
|
sleep(self.keepalive_time).await; |
|
|
|
if let Some(mut stream) = self.get_server_connection().await { |
|
|
|
if stream.last_activity.elapsed() < self.keepalive_time { |
|
|
|
self.release_server_connection(stream).await; |
|
|
|
continue; |
|
|
|
} |
|
|
|
let Some(mut stream) = self.pop_server_connection_from_queue().await else { |
|
|
|
continue; |
|
|
|
}; |
|
|
|
|
|
|
|
let Some(mut stream_reader) = stream.get_reader() else { |
|
|
|
continue; |
|
|
|
}; |
|
|
|
|
|
|
|
let Some(mut stream_writer) = stream.get_writer() else { |
|
|
|
continue; |
|
|
|
}; |
|
|
|
let local_addr = stream_writer.local_addr(); |
|
|
|
log::debug!("{:?}:{} send keepalive", local_addr, stream.id()); |
|
|
|
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(stream.id()).into(); |
|
|
|
if let Err(e) = stream_writer.write_all(&keepalive_packet).await { |
|
|
|
log::warn!("{:?}:{} send keepalive failed: {}", local_addr, stream.id(), e); |
|
|
|
continue; |
|
|
|
} |
|
|
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { |
|
|
|
Ok(UdpGwResponse::KeepAlive) => { |
|
|
|
stream.update_activity(); |
|
|
|
self.release_server_connection_full(stream, stream_reader, stream_writer).await; |
|
|
|
} |
|
|
|
Ok(v) => log::warn!("{:?}:{} keepalive unexpected response: {:?}", local_addr, stream.id(), v), |
|
|
|
Err(e) => log::warn!("{:?}:{} keepalive no response, error \"{}\"", local_addr, stream.id(), e), |
|
|
|
if stream.is_closed() { |
|
|
|
// This stream will be dropped
|
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if stream.last_activity.elapsed() < self.keepalive_time { |
|
|
|
self.store_server_connection(stream).await; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
let Some(mut stream_reader) = stream.get_reader() else { |
|
|
|
continue; |
|
|
|
}; |
|
|
|
|
|
|
|
let Some(mut stream_writer) = stream.get_writer() else { |
|
|
|
continue; |
|
|
|
}; |
|
|
|
let local_addr = stream_writer.local_addr()?; |
|
|
|
let sn = stream.serial_number(); |
|
|
|
log::trace!("stream {} {:?} send keepalive", sn, local_addr); |
|
|
|
let keepalive_packet: Vec<u8> = Packet::build_keepalive_packet(stream.new_packet_id()).into(); |
|
|
|
if let Err(e) = stream_writer.write_all(&keepalive_packet).await { |
|
|
|
log::warn!("stream {} {:?} send keepalive failed: {}", sn, local_addr, e); |
|
|
|
continue; |
|
|
|
} |
|
|
|
match UdpGwClient::recv_udpgw_packet(self.udp_mtu, 10, &mut stream_reader).await { |
|
|
|
Ok(UdpGwResponse::KeepAlive) => { |
|
|
|
stream.update_activity(); |
|
|
|
self.store_server_connection_full(stream, stream_reader, stream_writer).await; |
|
|
|
log::trace!("stream {} {:?} keepalive success", sn, local_addr); |
|
|
|
} |
|
|
|
Ok(v) => log::warn!("stream {} {:?} keepalive unexpected response: {:?}", sn, local_addr, v), |
|
|
|
Err(e) => log::warn!("stream {} {:?} keepalive no response, error \"{}\"", sn, local_addr, e), |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@ -474,10 +514,10 @@ impl UdpGwClient { |
|
|
|
pub(crate) fn parse_udp_response(udp_mtu: u16, data: &[u8]) -> Result<UdpGwResponse> { |
|
|
|
let packet = Packet::try_from(data)?; |
|
|
|
let flags = packet.header.flags; |
|
|
|
if flags & UDPGW_FLAG_ERR != 0 { |
|
|
|
if flags & UdpFlag::ERR == UdpFlag::ERR { |
|
|
|
return Ok(UdpGwResponse::Error); |
|
|
|
} |
|
|
|
if flags & UDPGW_FLAG_KEEPALIVE != 0 { |
|
|
|
if flags & UdpFlag::KEEPALIVE == UdpFlag::KEEPALIVE { |
|
|
|
return Ok(UdpGwResponse::KeepAlive); |
|
|
|
} |
|
|
|
if packet.data.len() > udp_mtu as usize { |
|
|
|
|