diff --git a/src/http.rs b/src/http.rs index cc312fa..f5268eb 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,4 @@ -use crate::tun2proxy::{Connection, TcpProxy, IncomingDirection, OutgoingDirection, OutgoingDataEvent, IncomingDataEvent, ConnectionManager}; +use crate::tun2proxy::{Connection, TcpProxy, IncomingDirection, OutgoingDirection, OutgoingDataEvent, IncomingDataEvent, ConnectionManager, ProxyError}; use std::collections::VecDeque; use std::net::SocketAddr; @@ -6,6 +6,7 @@ use std::net::SocketAddr; #[allow(dead_code)] enum HttpState { SendRequest, + ExpectStatusCode, ExpectResponse, Established } @@ -23,7 +24,7 @@ pub struct HttpConnection { impl HttpConnection { fn new(connection: &Connection) -> Self { let mut result = Self { - state: HttpState::ExpectResponse, + state: HttpState::ExpectStatusCode, client_inbuf: Default::default(), server_inbuf: Default::default(), client_outbuf: Default::default(), @@ -34,29 +35,27 @@ impl HttpConnection { result.server_outbuf.extend(b"CONNECT ".iter()); - result.destination_to_server_outbuf(connection); + result.server_outbuf.extend(connection.dst.to_string().as_bytes()); result.server_outbuf.extend(b" HTTP/1.1\r\nHost: ".iter()); - result.destination_to_server_outbuf(connection); + result.server_outbuf.extend(connection.dst.to_string().as_bytes()); result.server_outbuf.extend(b"\r\n\r\n".iter()); result } - fn destination_to_server_outbuf(&mut self, connection: &Connection) { - let ipv6 = connection.dst.is_ipv6(); - if ipv6 { - self.server_outbuf.extend(b"[".iter()); - } - self.server_outbuf.extend(connection.dst.ip().to_string().as_bytes()); - if ipv6 { - self.server_outbuf.extend(b"]".iter()); - } - self.server_outbuf.extend(b":".iter()); - self.server_outbuf.extend(connection.dst.port().to_string().as_bytes()); - } - - fn state_change(&mut self) { + fn state_change(&mut self) -> Result<(), ProxyError> { match self.state { + HttpState::ExpectStatusCode if self.server_inbuf.len() >= "HTTP/1.1 200 ".len() => { + let status_line: Vec = self.server_inbuf.range(0.."HTTP/1.1 200 ".len()).map(|&x| x).collect(); + let slice = &status_line.as_slice()[0.."HTTP/1.1 2".len()]; + if slice != b"HTTP/1.1 2" && slice != b"HTTP/1.0 2" + || self.server_inbuf["HTTP/1.1 200 ".len() - 1] != b' '{ + let status_str = String::from_utf8_lossy(&status_line.as_slice()[0.."HTTP/1.1 200".len()]); + return Err(ProxyError::new("Expected success status code. Server replied with ".to_owned() + &*status_str + ".")); + } + self.state = HttpState::ExpectResponse; + return self.state_change(); + } HttpState::ExpectResponse => { let mut counter = 0usize; for b_ref in self.server_inbuf.iter() { @@ -74,13 +73,8 @@ impl HttpConnection { self.server_outbuf.append(&mut self.data_buf); self.data_buf.clear(); - self.client_outbuf.extend(self.server_inbuf.iter()); - self.server_outbuf.extend(self.client_inbuf.iter()); - self.server_inbuf.clear(); - self.client_inbuf.clear(); - self.state = HttpState::Established; - return; + return self.state_change(); } } @@ -93,15 +87,15 @@ impl HttpConnection { self.client_inbuf.clear(); } _ => { - unreachable!(); } } + Ok(()) } } impl TcpProxy for HttpConnection { - fn push_data(&mut self, event: IncomingDataEvent<'_>) { + fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), ProxyError> { let direction = event.direction; let buffer = event.buffer; match direction { @@ -117,7 +111,7 @@ impl TcpProxy for HttpConnection { } } - self.state_change(); + self.state_change() } @@ -143,6 +137,10 @@ impl TcpProxy for HttpConnection { }; return event; } + + fn connection_established(&self) -> bool { + return self.state == HttpState::Established + } } pub struct HttpManager { diff --git a/src/main.rs b/src/main.rs index 5fe71a9..14a3401 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ #![feature(deque_make_contiguous)] +#![feature(deque_range)] + mod virtdevice; mod socks5; mod http; @@ -7,8 +9,7 @@ mod tun2proxy; use socks5::*; use crate::http::HttpManager; use crate::tun2proxy::TunToProxy; -use std::net::SocketAddr; -use std::str::FromStr; +use std::net::ToSocketAddrs; fn main() { let matches = clap::App::new(env!("CARGO_PKG_NAME")) @@ -44,8 +45,10 @@ fn main() { let tun_name = matches.value_of("tun").unwrap(); let mut ttp = TunToProxy::new(tun_name); if let Some(addr) = matches.value_of("socks5_server") { - if let Ok(server) = SocketAddr::from_str(addr) + if let Ok(mut servers) = addr.to_socket_addrs() { + let server = servers.next().unwrap(); + println!("SOCKS5 server: {}", server); ttp.add_connection_manager(Box::new(Socks5Manager::new(server))); } else { eprintln!("Invalid server address."); @@ -54,8 +57,10 @@ fn main() { } if let Some(addr) = matches.value_of("http_server") { - if let Ok(server) = SocketAddr::from_str(addr) + if let Ok(mut servers) = addr.to_socket_addrs() { + let server = servers.next().unwrap(); + println!("HTTP server: {}", server); ttp.add_connection_manager(Box::new(HttpManager::new(server))); } else { eprintln!("Invalid server address."); diff --git a/src/socks5.rs b/src/socks5.rs index fd3a2f7..bbae431 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -1,4 +1,4 @@ -use crate::tun2proxy::{Connection, OutgoingDirection, OutgoingDataEvent, IncomingDirection, IncomingDataEvent, ConnectionManager, TcpProxy}; +use crate::tun2proxy::{Connection, OutgoingDirection, OutgoingDataEvent, IncomingDirection, IncomingDataEvent, ConnectionManager, TcpProxy, ProxyError}; use std::collections::VecDeque; use std::net::{IpAddr, SocketAddr}; @@ -28,6 +28,26 @@ enum SocksAuthentication { Password = 2 } +#[allow(dead_code)] +#[repr(u8)] +#[derive(Debug, Eq, PartialEq)] +enum SocksReplies { + Succeeded, + GeneralFailure, + ConnectionDisallowed, + NetworkUnreachable, + ConnectionRefused, + TtlExpired, + CommandUnsupported, + AddressUnsupported +} + +impl std::fmt::Display for SocksReplies { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + pub struct SocksConnection { connection: Connection, state: SocksState, @@ -54,20 +74,23 @@ impl SocksConnection { result } - fn forward_data(&mut self) { - self.client_outbuf.extend(self.server_inbuf.iter()); - self.server_outbuf.extend(self.client_inbuf.iter()); - self.server_inbuf.clear(); - self.client_inbuf.clear(); - } - - pub fn state_change(&mut self) { + pub fn state_change(&mut self) -> Result<(), ProxyError> { let dst_ip = self.connection.dst.ip(); match self.state { - SocksState::ServerHello if self.server_inbuf.len() == 2 => { - assert!(self.server_inbuf[0] == 5 && self.server_inbuf[1] == 0); + SocksState::ServerHello if self.server_inbuf.len() >= 2 => { + if self.server_inbuf[0] != 5 { + return Err(ProxyError::new( + "SOCKS server replied with an unexpected version.".into())); + } + + if self.server_inbuf[1] != 0 { + return Err(ProxyError::new( + "SOCKS server requires an unsupported authentication method.".into())); + } + + self.server_inbuf.drain(0..2); let cmd = if dst_ip.is_ipv4() { 1 } else { 4 }; @@ -82,31 +105,36 @@ impl SocksConnection { ]); self.state = SocksState::ReceiveResponse; - } - - SocksState::ServerHello if self.server_inbuf.len() > 2 => { - panic!("Socks protocol error!") + return self.state_change(); } SocksState::ReceiveResponse if self.server_inbuf.len() >= 4 => { - let _ver = self.server_inbuf[0]; - let _rep = self.server_inbuf[1]; + let ver = self.server_inbuf[0]; + let rep = self.server_inbuf[1]; let _rsv = self.server_inbuf[2]; let atyp = self.server_inbuf[3]; + if ver != 5 { + return Err(ProxyError::new("SOCKS server replied with an unexpected version.".into())); + } + + if rep != 0 { + return Err(ProxyError::new("SOCKS connection unsuccessful.".into())); + } + if atyp != SocksAddressType::Ipv4 as u8 && atyp != SocksAddressType::Ipv6 as u8 && atyp != SocksAddressType::DomainName as u8 { - panic!("Invalid address type"); + return Err(ProxyError::new("SOCKS server replied with unrecognized address type.".into())); } if atyp == SocksAddressType::DomainName as u8 && self.server_inbuf.len() < 5 { - return; + return Ok(()); } if atyp == SocksAddressType::DomainName as u8 && self.server_inbuf.len() < 7 + (self.server_inbuf[4] as usize) { - return; + return Ok(()); } let message_length = if atyp == SocksAddressType::Ipv4 as u8 { @@ -121,21 +149,25 @@ impl SocksConnection { self.server_outbuf.append(&mut self.data_buf); self.data_buf.clear(); - self.forward_data(); self.state = SocksState::Established; + return self.state_change(); } SocksState::Established => { - self.forward_data(); + self.client_outbuf.extend(self.server_inbuf.iter()); + self.server_outbuf.extend(self.client_inbuf.iter()); + self.server_inbuf.clear(); + self.client_inbuf.clear(); } _ => {} } + Ok(()) } } impl TcpProxy for SocksConnection { - fn push_data(&mut self, event: IncomingDataEvent<'_>) { + fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), ProxyError> { let direction = event.direction; let buffer = event.buffer; match direction { @@ -151,7 +183,7 @@ impl TcpProxy for SocksConnection { } } - self.state_change(); + self.state_change() } @@ -177,6 +209,10 @@ impl TcpProxy for SocksConnection { }; return event; } + + fn connection_established(&self) -> bool { + return self.state == SocksState::Established + } } pub struct Socks5Manager { diff --git a/src/tun2proxy.rs b/src/tun2proxy.rs index 7be9556..8e295aa 100644 --- a/src/tun2proxy.rs +++ b/src/tun2proxy.rs @@ -14,7 +14,23 @@ use smoltcp::socket::{SocketHandle, SocketSet, TcpSocket, TcpSocketBuffer}; use smoltcp::time::Instant; use smoltcp::wire::{IpAddress, IpCidr, Ipv4Address, Ipv4Packet, TcpPacket, UdpPacket, Ipv6Packet}; use crate::virtdevice::VirtualTunDevice; +use std::net::Shutdown::Both; +pub struct ProxyError { + message: String +} + +impl ProxyError { + pub fn new(message: String) -> Self { + Self { + message + } + } + + pub fn message(&self) -> String { + self.message.clone() + } +} #[derive(Hash, Clone, Copy)] pub struct Connection { @@ -23,6 +39,12 @@ pub struct Connection { pub proto: u8 } +impl std::fmt::Display for Connection { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} -> {}", self.src, self.dst) + } +} + impl Eq for Connection {} impl PartialEq for Connection { @@ -141,9 +163,10 @@ struct ConnectionState { } pub(crate) trait TcpProxy { - fn push_data(&mut self, event: IncomingDataEvent<'_>); + fn push_data(&mut self, event: IncomingDataEvent<'_>) -> Result<(), ProxyError>; fn consume_data(&mut self, dir: OutgoingDirection, size: usize); fn peek_data(&mut self, dir: OutgoingDirection) -> OutgoingDataEvent; + fn connection_established(&self) -> bool; } pub(crate) trait ConnectionManager { @@ -227,7 +250,7 @@ impl<'a> TunToProxy<'a> { let mut connection_state = self.connections.remove(connection).unwrap(); self.token_to_connection.remove(&connection_state.token); self.poll.registry().deregister(&mut connection_state.mio_stream).unwrap(); - println!("[{:?}] CLOSE {} -> {} (TCP)", chrono::offset::Local::now(), connection.src, connection.dst); + println!("[{:?}] CLOSE {}", chrono::offset::Local::now(), connection); } fn get_connection_manager(&self, connection: &Connection) -> Option<&Box>{ @@ -239,25 +262,35 @@ impl<'a> TunToProxy<'a> { None } + fn print_error(error: ProxyError) { + println!("Error: {}", error.message()); + } + fn tunsocket_read_and_forward(&mut self, connection: &Connection) { if let Some(handler) = self.managers.get_mut(&connection) { let closed = { let conn_info = self.connections.get_mut(&connection).unwrap(); let mut socket = self.socketset.get::(conn_info.smoltcp_handle); - while socket.can_recv() { + let mut error = Ok(()); + while socket.can_recv() && error.is_ok() { socket.recv(|data| { let event = IncomingDataEvent { direction: IncomingDirection::FromClient, buffer: data, }; - handler.push_data(event); + error = handler.push_data(event); (data.len(), ()) }).unwrap(); } - socket.state() == smoltcp::socket::TcpState::CloseWait + if error.is_err() { + Self::print_error(error.unwrap_err()); + true + } else { + socket.state() == smoltcp::socket::TcpState::CloseWait + } }; if closed { @@ -284,7 +317,11 @@ impl<'a> TunToProxy<'a> { socket.listen(connection.dst).unwrap(); let handle = self.socketset.add(socket); - let socket = MioTcp::new_v4().unwrap(); + let socket = if server.is_ipv4() { + MioTcp::new_v4().unwrap() + } else { + MioTcp::new_v6().unwrap() + }; let client = socket.connect(server).unwrap(); let token = Token(self.next_token); @@ -309,7 +346,7 @@ impl<'a> TunToProxy<'a> { } - println!("[{:?}] CONNECT {} -> {} (TCP)", chrono::offset::Local::now(), connection.src, connection.dst); + println!("[{:?}] CONNECT {}", chrono::offset::Local::now(), connection); } else if !self.connections.contains_key(&connection) { return; } @@ -388,16 +425,6 @@ impl<'a> TunToProxy<'a> { fn mio_socket_event(&mut self, event: &Event) { let connection = *self.token_to_connection.get(&event.token()).unwrap(); - if event.is_read_closed() { - { - let mut socket = self.socketset.get::(self.connections.get(&connection).unwrap().smoltcp_handle); - socket.close(); - } - self.expect_smoltcp_send(); - self.remove_connection(&connection.clone()); - return; - } - if event.is_readable() { { let conn = self.managers.get_mut(&connection).unwrap(); @@ -406,12 +433,32 @@ impl<'a> TunToProxy<'a> { let mut buf = [0u8; 4096]; let read = state.mio_stream.read(&mut buf).unwrap(); + if read == 0 { + { + let mut socket = self.socketset.get::(self.connections.get(&connection).unwrap().smoltcp_handle); + socket.close(); + } + self.expect_smoltcp_send(); + self.remove_connection(&connection.clone()); + return; + } + let event = IncomingDataEvent { direction: IncomingDirection::FromServer, buffer: &buf[0..read], }; - conn.push_data(event); + if let Err(error) = conn.push_data(event) { + state.mio_stream.shutdown(Both).unwrap(); + { + let mut socket = self.socketset.get::(self.connections.get(&connection).unwrap().smoltcp_handle); + socket.close(); + } + self.expect_smoltcp_send(); + Self::print_error(error); + self.remove_connection(&connection.clone()); + return; + } } // We have read from the proxy server and pushed the data to the connection handler.