From c9297124e116a89f74df36a8dd0fae52ec30efd8 Mon Sep 17 00:00:00 2001 From: "B. Blechschmidt" Date: Thu, 23 Mar 2023 13:03:01 +0100 Subject: [PATCH] Implement first, unfinished version of DNS support and fix incomplete TCP stream bug This commit does two things: First, it implements a first, unfinished version of the virtual DNS functionality. This feature is incomplete and has known bugs. Since it needs to be enabled manually, this is not a huge issue so far. Second, the commit fixes a bug where TCP streams where not properly relayed, causing TCP connections to stall. --- src/lib.rs | 7 +- src/main.rs | 11 ++- src/tun2proxy.rs | 223 +++++++++++++++++++++++++++++++++++++---------- src/virtdns.rs | 179 +++++++++++++++++++++++++++++++++++++ tests/proxy.rs | 9 +- 5 files changed, 375 insertions(+), 54 deletions(-) create mode 100644 src/virtdns.rs diff --git a/src/lib.rs b/src/lib.rs index 248e8b8..df2910d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ use crate::error::{s2e, Error}; -use crate::tun2proxy::Credentials; +use crate::tun2proxy::{Credentials, Options}; use crate::{http::HttpManager, socks5::Socks5Manager, tun2proxy::TunToProxy}; use std::net::{SocketAddr, ToSocketAddrs}; @@ -8,6 +8,7 @@ pub mod http; pub mod socks5; pub mod tun2proxy; pub mod virtdevice; +pub mod virtdns; #[derive(Clone, Debug)] pub struct Proxy { @@ -75,8 +76,8 @@ impl std::fmt::Display for ProxyType { } } -pub fn main_entry(tun: &str, proxy: Proxy) { - let mut ttp = TunToProxy::new(tun); +pub fn main_entry(tun: &str, proxy: Proxy, options: Options) { + let mut ttp = TunToProxy::new(tun, options); match proxy.proxy_type { ProxyType::Socks5 => { ttp.add_connection_manager(Socks5Manager::new(proxy.addr, proxy.credentials)); diff --git a/src/main.rs b/src/main.rs index 87dab34..11ffde4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use clap::Parser; use env_logger::Env; +use tun2proxy::tun2proxy::Options; use tun2proxy::{main_entry, Proxy}; /// Tunnel interface to proxy @@ -14,6 +15,9 @@ struct Args { /// The proxy URL in the form proto://[username[:password]@]host:port #[arg(short, long = "proxy", value_parser = Proxy::from_url, value_name = "URL")] proxy: Proxy, + + #[arg(short, long = "dns")] + virtual_dns: bool, } fn main() { @@ -24,5 +28,10 @@ fn main() { let proxy_type = args.proxy.proxy_type; log::info!("Proxy {proxy_type} server: {addr}"); - main_entry(&args.tun, args.proxy); + let mut options = Options::new(); + if args.virtual_dns { + options = options.with_virtual_dns(); + } + + main_entry(&args.tun, args.proxy, options); } diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 28a3ed3..94ac443 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -1,5 +1,7 @@ use crate::error::Error; +use crate::tun2proxy::DestinationHost::Hostname; use crate::virtdevice::VirtualTunDevice; +use crate::virtdns::VirtualDns; use log::{error, info}; use mio::event::Event; use mio::net::TcpStream; @@ -12,7 +14,7 @@ use smoltcp::time::Instant; use smoltcp::wire::{ IpAddress, IpCidr, Ipv4Address, Ipv4Packet, Ipv6Address, Ipv6Packet, TcpPacket, UdpPacket, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::convert::From; use std::fmt::{Display, Formatter}; use std::io::{Read, Write}; @@ -55,6 +57,20 @@ impl From for SocketAddr { } } +impl From<&Destination> for SocketAddr { + fn from(value: &Destination) -> Self { + SocketAddr::new( + match value.host { + DestinationHost::Address(addr) => addr, + DestinationHost::Hostname(_) => { + panic!("Failed to convert hostname destination into socket address") + } + }, + value.port, + ) + } +} + impl From for Destination { fn from(addr: SocketAddr) -> Self { Self { @@ -66,17 +82,32 @@ impl From for Destination { impl Display for Destination { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}:{}", self.host.to_string(), self.port) + let host_part = match self.host { + DestinationHost::Address(addr) => match addr { + IpAddr::V4(_) => addr.to_string(), + IpAddr::V6(_) => format!("[{addr}]"), + }, + Hostname(_) => self.host.to_string(), + }; + write!(f, "{}:{}", host_part, self.port) } } #[derive(Hash, Clone, Eq, PartialEq)] -pub(crate) struct Connection { +pub struct Connection { pub(crate) src: std::net::SocketAddr, pub(crate) dst: Destination, pub(crate) proto: u8, } +impl Connection { + fn to_named(&self, name: String) -> Self { + let mut result = self.clone(); + result.dst.host = Hostname(name); + result + } +} + impl std::fmt::Display for Connection { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{} -> {}", self.src, self.dst) @@ -193,11 +224,14 @@ fn connection_tuple(frame: &[u8]) -> Option<(Connection, bool, usize, usize)> { } } +const WRITE_CLOSED: u8 = 1; + struct ConnectionState { smoltcp_handle: SocketHandle, mio_stream: TcpStream, token: Token, handler: std::boxed::Box, + smoltcp_socket_state: u8, } #[derive(Default, Clone, Debug)] @@ -234,6 +268,22 @@ pub(crate) trait ConnectionManager { fn get_credentials(&self) -> &Option; } +#[derive(Default, Clone, Debug)] +pub struct Options { + virtdns: Option, +} + +impl Options { + pub fn new() -> Self { + Default::default() + } + + pub fn with_virtual_dns(mut self) -> Self { + self.virtdns = Some(VirtualDns::new()); + self + } +} + pub(crate) struct TunToProxy<'a> { tun: TunTapInterface, poll: Poll, @@ -246,10 +296,12 @@ pub(crate) struct TunToProxy<'a> { token_to_connection: HashMap, sockets: SocketSet<'a>, device: VirtualTunDevice, + options: Options, + write_sockets: HashSet, } impl<'a> TunToProxy<'a> { - pub(crate) fn new(interface: &str) -> Self { + pub(crate) fn new(interface: &str, options: Options) -> Self { let tun_token = Token(0); let tun = TunTapInterface::new(interface, Medium::Ip).unwrap(); let poll = Poll::new().unwrap(); @@ -294,6 +346,8 @@ impl<'a> TunToProxy<'a> { connection_managers: Default::default(), sockets: SocketSet::new([]), device: virt, + options, + write_sockets: Default::default(), } } @@ -320,7 +374,8 @@ impl<'a> TunToProxy<'a> { fn remove_connection(&mut self, connection: &Connection) { let mut connection_state = self.connections.remove(connection).unwrap(); - self.token_to_connection.remove(&connection_state.token); + let token = &connection_state.token; + self.token_to_connection.remove(token); self.poll .registry() .deregister(&mut connection_state.mio_stream) @@ -343,7 +398,6 @@ impl<'a> TunToProxy<'a> { fn tunsocket_read_and_forward(&mut self, connection: &Connection) { if let Some(state) = self.connections.get_mut(connection) { let closed = { - // let socket = self.iface.get_socket::(state.smoltcp_handle); let socket = self.sockets.get_mut::(state.smoltcp_handle); let mut error = Ok(()); while socket.can_recv() && error.is_ok() { @@ -369,6 +423,9 @@ impl<'a> TunToProxy<'a> { } }; + // Expect ACKs etc. from smoltcp sockets. + self.expect_smoltcp_send(); + if closed { let connection_state = self.connections.get_mut(connection).unwrap(); connection_state @@ -384,22 +441,32 @@ impl<'a> TunToProxy<'a> { if let Some((connection, first_packet, _payload_offset, _payload_size)) = connection_tuple(frame) { - if connection.proto == smoltcp::wire::IpProtocol::Tcp.into() { - let cm = self.get_connection_manager(&connection); + let resolved_conn = match &self.options.virtdns { + None => connection.clone(), + Some(virt_dns) => { + match virt_dns.ip_to_name(&SocketAddr::from(&connection.dst).ip()) { + None => connection.clone(), + Some(name) => connection.to_named(name.clone()), + } + } + }; + if resolved_conn.proto == smoltcp::wire::IpProtocol::Tcp.into() { + let cm = self.get_connection_manager(&resolved_conn); if cm.is_none() { return; } let server = cm.unwrap().get_server(); if first_packet { for manager in self.connection_managers.iter_mut() { - if let Some(handler) = manager.new_connection(&connection, manager.clone()) + if let Some(handler) = + manager.new_connection(&resolved_conn, manager.clone()) { let mut socket = smoltcp::socket::tcp::Socket::new( smoltcp::socket::tcp::SocketBuffer::new(vec![0; 4096]), smoltcp::socket::tcp::SocketBuffer::new(vec![0; 4096]), ); socket.set_ack_delay(None); - let dst = connection.dst.clone(); + let dst = connection.dst; socket .listen(>::into(dst)) .unwrap(); @@ -415,9 +482,11 @@ impl<'a> TunToProxy<'a> { mio_stream: client, token, handler, + smoltcp_socket_state: 0, }; - self.token_to_connection.insert(token, connection.clone()); + self.token_to_connection + .insert(token, resolved_conn.clone()); self.poll .registry() .register( @@ -427,13 +496,13 @@ impl<'a> TunToProxy<'a> { ) .unwrap(); - self.connections.insert(connection.clone(), state); + self.connections.insert(resolved_conn.clone(), state); - info!("CONNECT {}", connection,); + info!("CONNECT {}", resolved_conn,); break; } } - } else if !self.connections.contains_key(&connection) { + } else if !self.connections.contains_key(&resolved_conn) { return; } @@ -446,18 +515,37 @@ impl<'a> TunToProxy<'a> { self.expect_smoltcp_send(); // Read from the smoltcp socket and push the data to the connection handler. - self.tunsocket_read_and_forward(&connection); + self.tunsocket_read_and_forward(&resolved_conn); // The connection handler builds up the connection or encapsulates the data. // Therefore, we now expect it to write data to the server. - self.write_to_server(&connection); - } else if connection.proto == smoltcp::wire::IpProtocol::Udp.into() { - // UDP is not yet supported - /*if _payload_offset > frame.len() || _payload_offset + _payload_offset > frame.len() { - return; + self.write_to_server(&resolved_conn); + } else if resolved_conn.proto == smoltcp::wire::IpProtocol::Udp.into() { + if let Some(virtual_dns) = &mut self.options.virtdns { + let payload = &frame[_payload_offset.._payload_offset + _payload_size]; + if let Some(response) = virtual_dns.receive_query(payload) { + let rx_buffer = smoltcp::socket::udp::PacketBuffer::new( + vec![smoltcp::socket::udp::PacketMetadata::EMPTY], + vec![0; 4096], + ); + let tx_buffer = smoltcp::socket::udp::PacketBuffer::new( + vec![smoltcp::socket::udp::PacketMetadata::EMPTY], + vec![0; 4096], + ); + let mut socket = smoltcp::socket::udp::Socket::new(rx_buffer, tx_buffer); + let dst = resolved_conn.dst.clone(); + socket + .bind(>::into(dst)) + .unwrap(); + socket + .send_slice(response.as_slice(), resolved_conn.src.into()) + .expect("failed to send DNS response"); + let handle = self.sockets.add(socket); + self.expect_smoltcp_send(); + self.sockets.remove(handle); + } } - let payload = &frame[_payload_offset.._payload_offset + _payload_size]; - self.virtual_dns.add_query(payload);*/ + // Otherwise, UDP is not yet supported. } } } @@ -485,15 +573,43 @@ impl<'a> TunToProxy<'a> { } } - fn write_to_client(&mut self, connection: &Connection) { - if let Some(state) = self.connections.get_mut(connection) { - let event = state.handler.peek_data(OutgoingDirection::ToClient); - let socket = self.sockets.get_mut::(state.smoltcp_handle); - if socket.may_send() { - let consumed = socket.send_slice(event.buffer).unwrap(); - state - .handler - .consume_data(OutgoingDirection::ToClient, consumed); + fn write_to_client(&mut self, token: Token, connection: &Connection) { + loop { + if let Some(state) = self.connections.get_mut(connection) { + let socket_state = state.smoltcp_socket_state; + let socket_handle = state.smoltcp_handle; + let event = state.handler.peek_data(OutgoingDirection::ToClient); + let buflen = event.buffer.len(); + let consumed; + { + let socket = self.sockets.get_mut::(socket_handle); + if socket.may_send() { + consumed = socket.send_slice(event.buffer).unwrap(); + state + .handler + .consume_data(OutgoingDirection::ToClient, consumed); + self.expect_smoltcp_send(); + if consumed < buflen { + self.write_sockets.insert(token); + break; + } else { + self.write_sockets.remove(&token); + if consumed == 0 { + break; + } + } + } else { + break; + } + } + let socket = self.sockets.get_mut::(socket_handle); + if socket_state & WRITE_CLOSED != 0 && consumed == buflen { + socket.close(); + self.expect_smoltcp_send(); + self.write_sockets.remove(&token); + self.remove_connection(connection); + break; + } } } } @@ -508,20 +624,33 @@ impl<'a> TunToProxy<'a> { } } + fn send_to_smoltcp(&mut self) { + let cloned = self.write_sockets.clone(); + for token in cloned.iter() { + if let Some(connection) = self.token_to_connection.get(token) { + self.write_to_client(*token, &connection.clone()); + } + } + } + fn mio_socket_event(&mut self, event: &Event) { if let Some(conn_ref) = self.token_to_connection.get(&event.token()) { let connection = conn_ref.clone(); - if event.is_readable() { + if event.is_readable() || event.is_read_closed() { { let state = self.connections.get_mut(&connection).unwrap(); - let mut buf = [0u8; 4096]; - let read_result = state.mio_stream.read(&mut buf); - let read = if let Ok(read_result) = read_result { - read_result - } else { - error!("READ from proxy: {}", read_result.as_ref().err().unwrap()); - 0 + // TODO: Move this reading process to its own function. + let mut vecbuf = Vec::::new(); + let read_result = state.mio_stream.read_to_end(&mut vecbuf); + let read = match read_result { + Ok(read_result) => read_result, + Err(error) => { + if error.kind() != std::io::ErrorKind::WouldBlock { + error!("READ from proxy: {}", error); + } + vecbuf.len() + } }; if read == 0 { @@ -536,11 +665,12 @@ impl<'a> TunToProxy<'a> { return; } - let event = IncomingDataEvent { + let data = vecbuf.as_slice(); + let data_event = IncomingDataEvent { direction: IncomingDirection::FromServer, - buffer: &buf[0..read], + buffer: &data[0..read], }; - if let Err(error) = state.handler.push_data(event) { + if let Err(error) = state.handler.push_data(data_event) { state.mio_stream.shutdown(Both).unwrap(); { let socket = self.sockets.get_mut::( @@ -553,14 +683,14 @@ impl<'a> TunToProxy<'a> { self.remove_connection(&connection.clone()); return; } + if event.is_read_closed() { + state.smoltcp_socket_state |= WRITE_CLOSED; + } } // We have read from the proxy server and pushed the data to the connection handler. // Thus, expect data to be processed (e.g. decapsulated) and forwarded to the client. - - //self.expect_smoltcp_send(); - self.write_to_client(&connection); - self.expect_smoltcp_send(); + self.write_to_client(event.token(), &connection); } if event.is_writable() { self.write_to_server(&connection); @@ -584,6 +714,7 @@ impl<'a> TunToProxy<'a> { self.mio_socket_event(event); } } + self.send_to_smoltcp(); } } } diff --git a/src/virtdns.rs b/src/virtdns.rs new file mode 100644 index 0000000..aa607f1 --- /dev/null +++ b/src/virtdns.rs @@ -0,0 +1,179 @@ +use smoltcp::wire::{IpCidr, Ipv4Cidr}; +use std::collections::{HashMap, LinkedList}; +use std::convert::TryInto; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::str::FromStr; + +#[derive(Eq, PartialEq, Debug)] +#[allow(dead_code, clippy::upper_case_acronyms)] +enum DnsRecordType { + A = 1, + AAAA = 28, +} + +#[derive(Eq, PartialEq, Debug)] +#[allow(dead_code)] +enum DnsClass { + IN = 1, +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub struct VirtualDns { + mapping: HashMap, + expiry: LinkedList, + cidr: IpCidr, + next_addr: std::net::IpAddr, +} + +impl Default for VirtualDns { + fn default() -> Self { + let start_addr = std::net::Ipv4Addr::from_str("198.18.0.0").unwrap(); + Self { + cidr: Ipv4Cidr::new(start_addr.into(), 15).into(), + next_addr: start_addr.into(), + mapping: Default::default(), + expiry: Default::default(), + } + } +} + +impl VirtualDns { + pub fn new() -> Self { + Default::default() + } + + pub fn receive_query(&mut self, data: &[u8]) -> Option> { + if data.len() < 17 { + return None; + } + // bit 1: Message is a query (0) + // bits 2 - 5: Standard query opcode (0) + // bit 6: Unused + // bit 7: Message is not truncated (0) + // bit 8: Recursion desired (1) + let is_supported_query = (data[2] & 0b11111011) == 0b00000001; + let num_queries = (data[4] as u16) << 8 | data[5] as u16; + if !is_supported_query || num_queries != 1 { + return None; + } + + let result = VirtualDns::parse_qname(data, 12); + result.as_ref()?; + let (qname, offset) = result.unwrap(); + if offset + 3 >= data.len() { + return None; + } + let qtype = (data[offset] as u16) << 8 | data[offset + 1] as u16; + let qclass = (data[offset + 2] as u16) << 8 | data[offset + 3] as u16; + + if qtype != DnsRecordType::A as u16 && qtype != DnsRecordType::AAAA as u16 + || qclass != DnsClass::IN as u16 + { + return None; + } + + log::info!("DNS query: {}", qname); + + let mut response = Vec::::new(); + response.extend(&data[0..offset + 4]); + response[2] |= 0x80; // Message is a response + response[3] |= 0x80; // Recursion available + response[6] = 0; + response[7] = if qtype == DnsRecordType::A as u16 { + 1 + } else { + 0 + }; // one answer record + + // zero other sections + response[8] = 0; + response[9] = 0; + response[10] = 0; + response[11] = 0; + + if let Some(ip) = self.name_to_ip(qname) { + if qtype == DnsRecordType::A as u16 { + response.extend(&[ + 0xc0, 0x0c, // Question name pointer + 0, 1, // Record type: A + 0, 1, // Class: IN + 0, 0, 0, 1, // TTL: 30 seconds + 0, 4, // Data length: 4 bytes + ]); + match ip as std::net::IpAddr { + IpAddr::V4(ip) => response.extend(ip.octets().as_ref()), + IpAddr::V6(ip) => response.extend(ip.octets().as_ref()), + }; + } + } else { + log::error!("Virtual IP space for DNS exhausted"); + response[7] = 0; // No answers + } + Some(response) + } + + fn increment_ip(addr: IpAddr) -> IpAddr { + let mut ip_bytes = match addr as std::net::IpAddr { + IpAddr::V4(ip) => Vec::::from(ip.octets()), + IpAddr::V6(ip) => Vec::::from(ip.octets()), + }; + for j in 0..ip_bytes.len() { + let i = ip_bytes.len() - 1 - j; + if ip_bytes[i] != 255 { + ip_bytes[i] += 1; + break; + } else { + ip_bytes[i] = 0; + } + } + if addr.is_ipv4() { + let bytes: [u8; 4] = ip_bytes.as_slice().try_into().unwrap(); + IpAddr::V4(Ipv4Addr::from(bytes)) + } else { + let bytes: [u8; 16] = ip_bytes.as_slice().try_into().unwrap(); + IpAddr::V6(Ipv6Addr::from(bytes)) + } + } + + pub fn ip_to_name(&self, addr: &IpAddr) -> Option<&String> { + self.mapping.get(addr) + } + + fn name_to_ip(&mut self, name: String) -> Option { + self.next_addr = Self::increment_ip(self.next_addr); + self.mapping.insert(self.next_addr, name); + // TODO: Check if next_addr is CIDR broadcast address and overflow. + // TODO: Caching. + Some(self.next_addr) + } + + fn parse_qname(data: &[u8], mut offset: usize) -> Option<(String, usize)> { + let label_type = data[offset] & 0xC0; + if label_type != 0x00 { + return None; + } + let mut qname = String::from(""); + loop { + if offset >= data.len() { + return None; + } + let label_len = data[offset]; + if label_len == 0 { + offset += 1; + break; + } + for _ in 0..label_len { + offset += 1; + if offset >= data.len() { + return None; + } + qname.push(data[offset] as char); + } + qname.push('.'); + offset += 1; + } + + Some((qname, offset)) + } +} diff --git a/tests/proxy.rs b/tests/proxy.rs index e9b97ca..be05578 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -13,6 +13,7 @@ mod tests { use nix::unistd::Pid; use serial_test::serial; + use tun2proxy::tun2proxy::Options; use tun2proxy::{main_entry, Proxy, ProxyType}; static TUN_TEST_DEVICE: &str = "tun0"; @@ -92,7 +93,7 @@ mod tests { } default_route_args.push(String::from(route_component)); } - if default_route_args.len() > 0 { + if !default_route_args.is_empty() { break; } } @@ -137,9 +138,9 @@ mod tests { } Ok(Fork::Child) => { prctl::set_death_signal(signal::SIGKILL as isize).unwrap(); // 9 == SIGKILL - main_entry(TUN_TEST_DEVICE, test.proxy); + main_entry(TUN_TEST_DEVICE, test.proxy, Options::new()); } - Err(_) => assert!(false), + Err(_) => panic!(), } } Err(_) => { @@ -150,7 +151,7 @@ mod tests { } fn require_var(var: &str) { - env::var(var).expect(format!("{var} environment variable required").as_str()); + env::var(var).unwrap_or_else(|_| panic!("{}", "{var} environment variable required")); } #[serial]