diff --git a/src/socks.rs b/src/socks.rs index 3c71e2a..ec4bf39 100644 --- a/src/socks.rs +++ b/src/socks.rs @@ -194,6 +194,13 @@ impl SocksProxyImpl { self.state_change() } + fn send_request_socks5(&mut self) -> Result<(), Error> { + protocol::Request::new(protocol::Command::Connect, self.info.dst.clone()) + .write_to_stream(&mut self.server_outbuf)?; + self.state = SocksState::ReceiveResponse; + self.state_change() + } + fn receive_connection_status(&mut self) -> Result<(), Error> { let response = protocol::Response::retrieve_from_stream(&mut self.server_inbuf.clone()); if let Err(e) = &response { @@ -217,14 +224,6 @@ impl SocksProxyImpl { self.state_change() } - fn send_request_socks5(&mut self) -> Result<(), Error> { - // self.server_outbuf.extend(&[self.version as u8, self.command as u8, 0]); - protocol::Request::new(protocol::Command::Connect, self.info.dst.clone()) - .write_to_stream(&mut self.server_outbuf)?; - self.state = SocksState::ReceiveResponse; - self.state_change() - } - fn relay_traffic(&mut self) -> Result<(), Error> { self.client_outbuf.extend(self.server_inbuf.iter()); self.server_outbuf.extend(self.client_inbuf.iter()); diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 1bb0170..e33f43b 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -12,7 +12,7 @@ use std::{ collections::{HashMap, HashSet}, convert::{From, TryFrom}, io::{Read, Write}, - net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, Shutdown::Both, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}, os::unix::io::AsRawFd, rc::Rc, str::FromStr, @@ -278,6 +278,7 @@ impl<'a> TunToProxy<'a> { self.connection_managers.push(manager); } + /// Read data from virtual device (remote server) and inject it into tun interface. fn expect_smoltcp_send(&mut self) -> Result<(), Error> { self.iface.poll(Instant::now(), &mut self.device, &mut self.sockets); @@ -297,7 +298,7 @@ impl<'a> TunToProxy<'a> { fn remove_connection(&mut self, info: &ConnectionInfo) -> Result<(), Error> { if let Some(mut conn) = self.connection_map.remove(info) { - _ = conn.mio_stream.shutdown(Both); + _ = conn.mio_stream.shutdown(Shutdown::Both); if let Some(handle) = conn.smoltcp_handle { let socket = self.sockets.get_mut::(handle); socket.close(); @@ -322,11 +323,10 @@ impl<'a> TunToProxy<'a> { } fn check_change_close_state(&mut self, info: &ConnectionInfo) -> Result<(), Error> { - let state = self.connection_map.get_mut(info); - if state.is_none() { - return Ok(()); - } - let state = state.unwrap(); + let state = match self.connection_map.get_mut(info) { + None => return Ok(()), + Some(state) => state, + }; let mut closed_ends = 0; if (state.close_state & SERVER_WRITE_CLOSED) == SERVER_WRITE_CLOSED && !state @@ -452,31 +452,12 @@ impl<'a> TunToProxy<'a> { continue; } let tcp_proxy_handler = tcp_proxy_handler?; - let mut socket = tcp::Socket::new( - tcp::SocketBuffer::new(vec![0; 1024 * 128]), - tcp::SocketBuffer::new(vec![0; 1024 * 128]), - ); - socket.set_ack_delay(None); - socket.listen(dst)?; - let handle = self.sockets.add(socket); - - let mut client = TcpStream::connect(server_addr)?; - let token = self.new_token(); - let i = Interest::READABLE; - self.poll.registry().register(&mut client, token, i)?; - - let state = TcpConnectState { - smoltcp_handle: Some(handle), - mio_stream: client, - token, + self.create_new_tcp_proxy_connection( + server_addr, + dst, tcp_proxy_handler, - close_state: 0, - wait_read: true, - wait_write: false, - }; - self.connection_map.insert(connection_info.clone(), state); - - self.token_to_info.insert(token, connection_info.clone()); + connection_info.clone(), + )?; log::info!("Connect done {} ({})", connection_info, dst); done = true; @@ -492,7 +473,7 @@ impl<'a> TunToProxy<'a> { log::trace!("Subsequent packet {} ({})", connection_info, dst); } - // Inject the packet to advance the smoltcp socket state + // Inject the packet to advance the remote proxy server smoltcp socket state self.device.inject_packet(frame); // Having advanced the socket state, we expect the socket to ACK @@ -509,8 +490,8 @@ impl<'a> TunToProxy<'a> { } else if connection_info.protocol == IpProtocol::Udp { log::trace!("{} ({})", connection_info, dst); let port = connection_info.dst.port(); + let payload = &frame[payload_offset..payload_offset + payload_size]; if let (Some(virtual_dns), true) = (&mut self.options.virtual_dns, port == 53) { - let payload = &frame[payload_offset..payload_offset + payload_size]; let response = virtual_dns.receive_query(payload)?; { let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 4096]); @@ -536,6 +517,41 @@ impl<'a> TunToProxy<'a> { Ok(()) } + fn create_new_tcp_proxy_connection( + &mut self, + server_addr: SocketAddr, + dst: SocketAddr, + tcp_proxy_handler: Box, + connection_info: ConnectionInfo, + ) -> Result<()> { + let mut socket = tcp::Socket::new( + tcp::SocketBuffer::new(vec![0; 1024 * 128]), + tcp::SocketBuffer::new(vec![0; 1024 * 128]), + ); + socket.set_ack_delay(None); + socket.listen(dst)?; + let handle = self.sockets.add(socket); + + let mut client = TcpStream::connect(server_addr)?; + let token = self.new_token(); + let i = Interest::READABLE; + self.poll.registry().register(&mut client, token, i)?; + + let state = TcpConnectState { + smoltcp_handle: Some(handle), + mio_stream: client, + token, + tcp_proxy_handler, + close_state: 0, + wait_read: true, + wait_write: false, + }; + self.connection_map.insert(connection_info.clone(), state); + + self.token_to_info.insert(token, connection_info.clone()); + Ok(()) + } + fn write_to_server(&mut self, info: &ConnectionInfo) -> Result<(), Error> { if let Some(state) = self.connection_map.get_mut(info) { let event = state.tcp_proxy_handler.peek_data(OutgoingDirection::ToServer); @@ -680,7 +696,7 @@ impl<'a> TunToProxy<'a> { if state.tcp_proxy_handler.reset_connection() { _ = self.poll.registry().deregister(&mut state.mio_stream); // Closes the connection with the proxy - state.mio_stream.shutdown(Both)?; + state.mio_stream.shutdown(Shutdown::Both)?; log::info!("RESET {}", conn_info); diff --git a/src/virtdevice.rs b/src/virtdevice.rs index fc862d9..721466c 100644 --- a/src/virtdevice.rs +++ b/src/virtdevice.rs @@ -3,6 +3,7 @@ use smoltcp::{ time::Instant, }; +/// Virtual device representing the remote proxy server. #[derive(Default)] pub struct VirtualTunDevice { capabilities: DeviceCapabilities,