diff --git a/src/http.rs b/src/http.rs index b7ddfce..cdbeca2 100644 --- a/src/http.rs +++ b/src/http.rs @@ -8,7 +8,7 @@ use crate::{ use base64::Engine; use httparse::Response; use smoltcp::wire::IpProtocol; -use socks5_impl::protocol::{Address, UserKey}; +use socks5_impl::protocol::UserKey; use std::{ cell::RefCell, collections::{hash_map::RandomState, HashMap, VecDeque}, @@ -52,7 +52,7 @@ pub struct HttpConnection { digest_state: Rc>>, before: bool, credentials: Option, - destination: Address, + info: ConnectionInfo, } static PROXY_AUTHENTICATE: &str = "Proxy-Authenticate"; @@ -80,7 +80,7 @@ impl HttpConnection { digest_state, before: false, credentials, - destination: info.dst.clone(), + info: info.clone(), }; res.send_tunnel_request()?; @@ -89,9 +89,9 @@ impl HttpConnection { fn send_tunnel_request(&mut self) -> Result<(), Error> { self.server_outbuf.extend(b"CONNECT "); - self.server_outbuf.extend(self.destination.to_string().as_bytes()); + self.server_outbuf.extend(self.info.dst.to_string().as_bytes()); self.server_outbuf.extend(b" HTTP/1.1\r\nHost: "); - self.server_outbuf.extend(self.destination.to_string().as_bytes()); + self.server_outbuf.extend(self.info.dst.to_string().as_bytes()); self.server_outbuf.extend(b"\r\n"); self.send_auth_data(if self.digest_state.borrow().is_none() { @@ -111,7 +111,7 @@ impl HttpConnection { match scheme { AuthenticationScheme::Digest => { - let uri = self.destination.to_string(); + let uri = self.info.dst.to_string(); let context = digest_auth::AuthContext::new_with_method( &credentials.username, @@ -318,6 +318,10 @@ impl HttpConnection { } impl TcpProxy for HttpConnection { + fn get_connection_info(&self) -> &ConnectionInfo { + &self.info + } + fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> { let direction = event.direction; let buffer = event.buffer; @@ -378,6 +382,10 @@ impl TcpProxy for HttpConnection { fn reset_connection(&self) -> bool { self.state == HttpState::Reset } + + fn get_udp_associate(&self) -> Option { + None + } } pub(crate) struct HttpManager { @@ -387,11 +395,7 @@ pub(crate) struct HttpManager { } impl ConnectionManager for HttpManager { - fn handles_connection(&self, info: &ConnectionInfo) -> bool { - info.protocol == IpProtocol::Tcp - } - - fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result, Error> { + fn new_tcp_proxy(&self, info: &ConnectionInfo, _: bool) -> Result, Error> { if info.protocol != IpProtocol::Tcp { return Err("Invalid protocol".into()); } diff --git a/src/lib.rs b/src/lib.rs index 01ea7ba..bdbf7a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,7 +130,7 @@ pub fn tun_to_proxy<'a>( ProxyType::Socks5 => Rc::new(SocksProxyManager::new(server, Version::V5, credentials)) as Rc, ProxyType::Http => Rc::new(HttpManager::new(server, credentials)) as Rc, }; - ttp.add_connection_manager(mgr); + ttp.set_connection_manager(Some(mgr)); Ok(ttp) } diff --git a/src/socks.rs b/src/socks.rs index ec4bf39..3e4a7a9 100644 --- a/src/socks.rs +++ b/src/socks.rs @@ -1,13 +1,12 @@ use crate::{ - error::Error, + error::{Error, Result}, tun2proxy::{ ConnectionInfo, ConnectionManager, Direction, IncomingDataEvent, IncomingDirection, OutgoingDataEvent, OutgoingDirection, TcpProxy, }, }; -use smoltcp::wire::IpProtocol; use socks5_impl::protocol::{self, handshake, password_method, Address, AuthMethod, StreamOperation, UserKey, Version}; -use std::{collections::VecDeque, net::SocketAddr}; +use std::{collections::VecDeque, convert::TryFrom, net::SocketAddr}; #[derive(Eq, PartialEq, Debug)] #[allow(dead_code)] @@ -31,10 +30,17 @@ struct SocksProxyImpl { data_buf: VecDeque, version: Version, credentials: Option, + command: protocol::Command, + udp_associate: Option, } impl SocksProxyImpl { - fn new(info: &ConnectionInfo, credentials: Option, version: Version) -> Result { + fn new( + info: &ConnectionInfo, + credentials: Option, + version: Version, + command: protocol::Command, + ) -> Result { let mut result = Self { info: info.clone(), state: SocksState::ServerHello, @@ -45,6 +51,8 @@ impl SocksProxyImpl { data_buf: VecDeque::default(), version, credentials, + command, + udp_associate: None, }; result.send_client_hello()?; Ok(result) @@ -150,11 +158,11 @@ impl SocksProxyImpl { return Err("SOCKS5 server requires an unsupported authentication method.".into()); } - if auth_method == AuthMethod::UserPass { - self.state = SocksState::SendAuthData; + self.state = if auth_method == AuthMethod::UserPass { + SocksState::SendAuthData } else { - self.state = SocksState::SendRequest; - } + SocksState::SendRequest + }; self.state_change() } @@ -195,8 +203,7 @@ impl SocksProxyImpl { } fn send_request_socks5(&mut self) -> Result<(), Error> { - protocol::Request::new(protocol::Command::Connect, self.info.dst.clone()) - .write_to_stream(&mut self.server_outbuf)?; + protocol::Request::new(self.command, self.info.dst.clone()).write_to_stream(&mut self.server_outbuf)?; self.state = SocksState::ReceiveResponse; self.state_change() } @@ -216,6 +223,11 @@ impl SocksProxyImpl { if response.reply != protocol::Reply::Succeeded { return Err(format!("SOCKS connection failed: {}", response.reply).into()); } + if self.command == protocol::Command::UdpAssociate { + self.udp_associate = Some(SocketAddr::try_from(&response.address)?); + assert!(self.data_buf.is_empty()); + log::debug!("UDP associate: {}", response.address); + } self.server_outbuf.append(&mut self.data_buf); self.data_buf.clear(); @@ -252,6 +264,10 @@ impl SocksProxyImpl { } impl TcpProxy for SocksProxyImpl { + fn get_connection_info(&self) -> &ConnectionInfo { + &self.info + } + fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error> { let direction = event.direction; let buffer = event.buffer; @@ -312,6 +328,10 @@ impl TcpProxy for SocksProxyImpl { fn reset_connection(&self) -> bool { false } + + fn get_udp_associate(&self) -> Option { + self.udp_associate + } } pub(crate) struct SocksProxyManager { @@ -321,19 +341,11 @@ pub(crate) struct SocksProxyManager { } impl ConnectionManager for SocksProxyManager { - fn handles_connection(&self, info: &ConnectionInfo) -> bool { - info.protocol == IpProtocol::Tcp - } - - fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result, Error> { - if info.protocol != IpProtocol::Tcp { - return Err("Invalid protocol".into()); - } - Ok(Box::new(SocksProxyImpl::new( - info, - self.credentials.clone(), - self.version, - )?)) + fn new_tcp_proxy(&self, info: &ConnectionInfo, udp_associate: bool) -> Result> { + use socks5_impl::protocol::Command::{Connect, UdpAssociate}; + let command = if udp_associate { UdpAssociate } else { Connect }; + let credentials = self.credentials.clone(); + Ok(Box::new(SocksProxyImpl::new(info, credentials, self.version, command)?)) } fn close_connection(&self, _: &ConnectionInfo) {} diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index e33f43b..ea1078d 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -7,7 +7,7 @@ use smoltcp::{ time::Instant, wire::{IpCidr, IpProtocol, Ipv4Packet, Ipv6Packet, TcpPacket, UdpPacket, UDP_HEADER_LEN}, }; -use socks5_impl::protocol::{Address, UserKey}; +use socks5_impl::protocol::{Address, StreamOperation, UdpHeader, UserKey}; use std::{ collections::{HashMap, HashSet}, convert::{From, TryFrom}, @@ -18,7 +18,7 @@ use std::{ str::FromStr, }; -#[derive(Hash, Clone, Eq, PartialEq, Debug)] +#[derive(Hash, Clone, Eq, PartialEq, PartialOrd, Ord, Debug)] pub(crate) struct ConnectionInfo { pub(crate) src: SocketAddr, pub(crate) dst: Address, @@ -36,7 +36,6 @@ impl Default for ConnectionInfo { } impl ConnectionInfo { - #[allow(dead_code)] pub fn new(src: SocketAddr, dst: Address, protocol: IpProtocol) -> Self { Self { src, dst, protocol } } @@ -133,11 +132,11 @@ fn connection_tuple(frame: &[u8]) -> Result<(ConnectionInfo, bool, usize, usize) let (ports, first_packet, payload_offset, payload_size) = get_transport_info(protocol, header_len, &frame[header_len..])?; - let info = ConnectionInfo { - src: SocketAddr::new(src_addr, ports.0), - dst: SocketAddr::new(dst_addr, ports.1).into(), + let info = ConnectionInfo::new( + SocketAddr::new(src_addr, ports.0), + SocketAddr::new(dst_addr, ports.1).into(), protocol, - }; + ); return Ok((info, first_packet, payload_offset, payload_size)); } @@ -154,11 +153,11 @@ fn connection_tuple(frame: &[u8]) -> Result<(ConnectionInfo, bool, usize, usize) let (ports, first_packet, payload_offset, payload_size) = get_transport_info(protocol, header_len, &frame[header_len..])?; - let info = ConnectionInfo { - src: SocketAddr::new(src_addr, ports.0), - dst: SocketAddr::new(dst_addr, ports.1).into(), + let info = ConnectionInfo::new( + SocketAddr::new(src_addr, ports.0), + SocketAddr::new(dst_addr, ports.1).into(), protocol, - }; + ); return Ok((info, first_packet, payload_offset, payload_size)); } Err("Neither IPv6 nor IPv4 packet".into()) @@ -178,12 +177,14 @@ struct TcpConnectState { } pub(crate) trait TcpProxy { + fn get_connection_info(&self) -> &ConnectionInfo; fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), Error>; fn consume_data(&mut self, dir: OutgoingDirection, size: usize); fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent; fn connection_established(&self) -> bool; fn have_data(&mut self, dir: Direction) -> bool; fn reset_connection(&self) -> bool; + fn get_udp_associate(&self) -> Option; } pub(crate) trait UdpProxy { @@ -192,8 +193,7 @@ pub(crate) trait UdpProxy { } pub(crate) trait ConnectionManager { - fn handles_connection(&self, info: &ConnectionInfo) -> bool; - fn new_tcp_proxy(&self, info: &ConnectionInfo) -> Result, Error>; + fn new_tcp_proxy(&self, info: &ConnectionInfo, udp_associate: bool) -> Result>; fn close_connection(&self, info: &ConnectionInfo); fn get_server_addr(&self) -> SocketAddr; fn get_credentials(&self) -> &Option; @@ -207,9 +207,8 @@ pub struct TunToProxy<'a> { poll: Poll, iface: Interface, connection_map: HashMap, - connection_managers: Vec>, + connection_manager: Option>, next_token: usize, - token_to_info: HashMap, sockets: SocketSet<'a>, device: VirtualTunDevice, options: Options, @@ -256,8 +255,7 @@ impl<'a> TunToProxy<'a> { iface, connection_map: HashMap::default(), next_token: usize::from(EXIT_TOKEN) + 1, - token_to_info: HashMap::default(), - connection_managers: Vec::default(), + connection_manager: None, sockets: SocketSet::new([]), device, options, @@ -274,8 +272,8 @@ impl<'a> TunToProxy<'a> { token } - pub(crate) fn add_connection_manager(&mut self, manager: Rc) { - self.connection_managers.push(manager); + pub(crate) fn set_connection_manager(&mut self, manager: Option>) { + self.connection_manager = manager; } /// Read data from virtual device (remote server) and inject it into tun interface. @@ -296,36 +294,41 @@ impl<'a> TunToProxy<'a> { Ok(()) } + fn find_info_by_token(&self, token: Token) -> Option<&ConnectionInfo> { + self.connection_map + .iter() + .find_map(|(info, state)| if state.token == token { Some(info) } else { None }) + } + + /// Destroy connection state machine fn remove_connection(&mut self, info: &ConnectionInfo) -> Result<(), Error> { - if let Some(mut conn) = self.connection_map.remove(info) { - _ = conn.mio_stream.shutdown(Shutdown::Both); - if let Some(handle) = conn.smoltcp_handle { + if let Some(mut state) = self.connection_map.remove(info) { + _ = state.mio_stream.shutdown(Shutdown::Both); + if let Some(handle) = state.smoltcp_handle { let socket = self.sockets.get_mut::(handle); socket.close(); self.sockets.remove(handle); } + + // FIXME: Does this line should be moved up to the beginning of this function? self.expect_smoltcp_send()?; - let token = &conn.token; - self.token_to_info.remove(token); - _ = self.poll.registry().deregister(&mut conn.mio_stream); + + _ = self.poll.registry().deregister(&mut state.mio_stream); + log::info!("Close {}", info); } Ok(()) } - fn get_connection_manager(&self, info: &ConnectionInfo) -> Option> { - for manager in self.connection_managers.iter() { - if manager.handles_connection(info) { - return Some(manager.clone()); - } - } - None + fn get_connection_manager(&self) -> Option> { + self.connection_manager.clone() } + /// Scan connection state machine and check if any connection should be closed. fn check_change_close_state(&mut self, info: &ConnectionInfo) -> Result<(), Error> { let state = match self.connection_map.get_mut(info) { - None => return Ok(()), Some(state) => state, + None => return Ok(()), }; let mut closed_ends = 0; if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED @@ -336,8 +339,9 @@ impl<'a> TunToProxy<'a> { .tcp_proxy_handler .have_data(Direction::Outgoing(OutgoingDirection::ToClient)) { - if let Some(socket_handle) = state.smoltcp_handle { - let socket = self.sockets.get_mut::(socket_handle); + if let Some(handle) = state.smoltcp_handle { + // Close tun interface + let socket = self.sockets.get_mut::(handle); socket.close(); } closed_ends += 1; @@ -351,17 +355,20 @@ impl<'a> TunToProxy<'a> { .tcp_proxy_handler .have_data(Direction::Outgoing(OutgoingDirection::ToServer)) { + // Close remote server _ = state.mio_stream.shutdown(Shutdown::Write); closed_ends += 1; } if closed_ends == 2 { + // Close connection state machine self.remove_connection(info)?; } Ok(()) } fn tunsocket_read_and_forward(&mut self, info: &ConnectionInfo) -> Result<(), Error> { + // 1. Read data from tun and write to proxy handler (remote server). // Scope for mutable borrow of self. { let state = match self.connection_map.get_mut(info) { @@ -393,10 +400,10 @@ impl<'a> TunToProxy<'a> { // need to send data. state.close_state |= CLIENT_WRITE_CLOSED; } - - // Expect ACKs etc. from smoltcp sockets. - self.expect_smoltcp_send()?; } + // 2. Write data from proxy handler (remote server) to tun. + // Expect ACKs etc. from smoltcp sockets. + self.expect_smoltcp_send()?; self.check_change_close_state(info)?; @@ -426,7 +433,12 @@ impl<'a> TunToProxy<'a> { // A raw packet was received on the tunnel interface. fn receive_tun(&mut self, frame: &mut [u8]) -> Result<(), Error> { let mut handler = || -> Result<(), Error> { - let (info, first_packet, payload_offset, payload_size) = connection_tuple(frame)?; + let result = connection_tuple(frame); + if let Err(error) = result { + log::info!("{}, ignored", error); + return Ok(()); + } + let (info, _first_packet, payload_offset, payload_size) = result?; let dst = SocketAddr::try_from(&info.dst)?; let connection_info = match &mut self.options.virtual_dns { None => info.clone(), @@ -439,33 +451,17 @@ impl<'a> TunToProxy<'a> { } } }; + + let manager = self.get_connection_manager().ok_or("get connection manager")?; + let server_addr = manager.get_server_addr(); + if connection_info.protocol == IpProtocol::Tcp { - let server_addr = self - .get_connection_manager(&connection_info) - .ok_or("get_connection_manager")? - .get_server_addr(); - if first_packet { - let mut done = false; - for manager in self.connection_managers.iter_mut() { - let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info); - if tcp_proxy_handler.is_err() { - continue; - } - let tcp_proxy_handler = tcp_proxy_handler?; - self.create_new_tcp_proxy_connection( - server_addr, - dst, - tcp_proxy_handler, - connection_info.clone(), - )?; - - log::info!("Connect done {} ({})", connection_info, dst); - done = true; - break; - } - if !done { - log::debug!("No connection manager for {} ({})", connection_info, dst); - } + if _first_packet { + let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, false)?; + let state = self.create_new_tcp_connection_state(server_addr, dst, tcp_proxy_handler)?; + self.connection_map.insert(connection_info.clone(), state); + + log::info!("Connect done {} ({})", connection_info, dst); } else if !self.connection_map.contains_key(&connection_info) { log::debug!("Not found {} ({})", connection_info, dst); return Ok(()); @@ -493,19 +489,33 @@ impl<'a> TunToProxy<'a> { let payload = &frame[payload_offset..payload_offset + payload_size]; if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == 53) { let response = virtual_dns.receive_query(payload)?; - { - let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); - let tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); - let mut socket = udp::Socket::new(rx_buffer, tx_buffer); - socket.bind(dst)?; - let meta = UdpMetadata::from(connection_info.src); - socket.send_slice(response.as_slice(), meta)?; - let handle = self.sockets.add(socket); - self.expect_smoltcp_send()?; - self.sockets.remove(handle); + self.send_udp_packet(dst, connection_info.src, response.as_slice())?; + } else { + // Another UDP packet + if !self.connection_map.contains_key(&connection_info) { + log::trace!("New UDP connection {} ({})", connection_info, dst); + let tcp_proxy_handler = manager.new_tcp_proxy(&connection_info, true)?; + let state = self.create_new_tcp_connection_state(server_addr, dst, tcp_proxy_handler)?; + self.connection_map.insert(connection_info.clone(), state); + } + + self.expect_smoltcp_send()?; + self.tunsocket_read_and_forward(&connection_info)?; + self.write_to_server(&connection_info)?; + + let mut s5_udp_data = Vec::::new(); + UdpHeader::new(0, connection_info.dst.clone()).write_to_stream(&mut s5_udp_data)?; + s5_udp_data.extend_from_slice(payload); + + let state = self.connection_map.get(&connection_info).ok_or("udp associate state")?; + if let Some(udp_associate) = state.tcp_proxy_handler.get_udp_associate() { + log::debug!("UDP associate address: {}", udp_associate); + // Send packets via UDP associate... + // self.send_udp_packet(connection_info.src, udp_associate, &s5_udp_data)?; + } else { + // UDP associate tunnel not ready yet, we must cache the packet... } } - // Otherwise, UDP is not yet supported. } else { log::warn!("Unsupported protocol: {} ({})", connection_info, dst); } @@ -517,13 +527,12 @@ impl<'a> TunToProxy<'a> { Ok(()) } - fn create_new_tcp_proxy_connection( + fn create_new_tcp_connection_state( &mut self, server_addr: SocketAddr, dst: SocketAddr, tcp_proxy_handler: Box, - connection_info: ConnectionInfo, - ) -> Result<()> { + ) -> Result { let mut socket = tcp::Socket::new( tcp::SocketBuffer::new(vec![0; 1024 * 128]), tcp::SocketBuffer::new(vec![0; 1024 * 128]), @@ -546,9 +555,18 @@ impl<'a> TunToProxy<'a> { wait_read: true, wait_write: false, }; - self.connection_map.insert(connection_info.clone(), state); + Ok(state) + } - self.token_to_info.insert(token, connection_info.clone()); + fn send_udp_packet(&mut self, src: SocketAddr, dst: SocketAddr, data: &[u8]) -> Result<()> { + let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); + let tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); + let mut socket = udp::Socket::new(rx_buffer, tx_buffer); + socket.bind(src)?; + socket.send_slice(data, UdpMetadata::from(dst))?; + let handle = self.sockets.add(socket); + self.expect_smoltcp_send()?; + self.sockets.remove(handle); Ok(()) } @@ -587,7 +605,7 @@ impl<'a> TunToProxy<'a> { fn write_to_client(&mut self, token: Token, info: &ConnectionInfo) -> Result<(), Error> { while let Some(state) = self.connection_map.get_mut(info) { - let socket_handle = match state.smoltcp_handle { + let handle = match state.smoltcp_handle { Some(handle) => handle, None => break, }; @@ -595,7 +613,7 @@ impl<'a> TunToProxy<'a> { let buflen = event.buffer.len(); let consumed; { - let socket = self.sockets.get_mut::(socket_handle); + let socket = self.sockets.get_mut::(handle); if socket.may_send() { if let Some(virtual_dns) = &mut self.options.virtual_dns { // Unwrapping is fine because every smoltcp socket is bound to an. @@ -635,11 +653,10 @@ impl<'a> TunToProxy<'a> { } fn send_to_smoltcp(&mut self) -> Result<(), Error> { - let cloned = self.write_sockets.clone(); - for token in cloned.iter() { - if let Some(connection) = self.token_to_info.get(token) { + for token in self.write_sockets.clone().into_iter() { + if let Some(connection) = self.find_info_by_token(token) { let connection = connection.clone(); - if let Err(error) = self.write_to_client(*token, &connection) { + if let Err(error) = self.write_to_client(token, &connection) { self.remove_connection(&connection)?; log::error!("Write to client: {}: ", error); } @@ -649,19 +666,19 @@ impl<'a> TunToProxy<'a> { } fn mio_socket_event(&mut self, event: &Event) -> Result<(), Error> { - let e = "connection not found"; - let conn_info = match self.token_to_info.get(&event.token()) { + let conn_info = match self.find_info_by_token(event.token()) { Some(conn_info) => conn_info.clone(), None => { // We may have closed the connection in an earlier iteration over the poll events, // e.g. because an event through the tunnel interface indicated that the connection // should be closed. - log::trace!("{e}"); + log::trace!("Connection info not found"); return Ok(()); } }; - let server = self.get_connection_manager(&conn_info).ok_or(e)?.get_server_addr(); + let e = "connection manager not found"; + let server = self.get_connection_manager().ok_or(e)?.get_server_addr(); let mut block = || -> Result<(), Error> { if event.is_readable() || event.is_read_closed() {