diff --git a/src/socks5.rs b/src/socks5.rs index 7448146..76889fa 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -6,6 +6,7 @@ use crate::tun2proxy::{ use crate::Credentials; use smoltcp::wire::IpProtocol; use std::collections::VecDeque; +use std::convert::TryFrom; use std::net::{IpAddr, SocketAddr}; use std::rc::Rc; @@ -22,13 +23,31 @@ enum SocksState { } #[repr(u8)] -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Debug)] enum SocksAddressType { Ipv4 = 1, DomainName = 3, Ipv6 = 4, } +impl TryFrom for SocksAddressType { + type Error = Error; + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(SocksAddressType::Ipv4), + 3 => Ok(SocksAddressType::DomainName), + 4 => Ok(SocksAddressType::Ipv6), + _ => Err(format!("Unknown address type: {}", value).into()), + } + } +} + +impl From for u8 { + fn from(value: SocksAddressType) -> Self { + value as u8 + } +} + #[derive(Copy, Clone)] pub enum SocksVersion { V4 = 4, @@ -82,11 +101,11 @@ impl SocksConnection { let mut result = Self { connection: connection.clone(), state: SocksState::ServerHello, - client_inbuf: Default::default(), - server_inbuf: Default::default(), - client_outbuf: Default::default(), - server_outbuf: Default::default(), - data_buf: Default::default(), + client_inbuf: VecDeque::default(), + server_inbuf: VecDeque::default(), + client_outbuf: VecDeque::default(), + server_outbuf: VecDeque::default(), + data_buf: VecDeque::default(), manager, version, }; @@ -235,29 +254,18 @@ impl SocksConnection { return Err("SOCKS5 connection unsuccessful.".into()); } - if atyp != SocksAddressType::Ipv4 as u8 - && atyp != SocksAddressType::Ipv6 as u8 - && atyp != SocksAddressType::DomainName as u8 - { - return Err("SOCKS5 server replied with unrecognized address type.".into()); - } - - if atyp == SocksAddressType::DomainName as u8 && self.server_inbuf.len() < 5 { - return Ok(()); - } - - if atyp == SocksAddressType::DomainName as u8 - && self.server_inbuf.len() < 7 + (self.server_inbuf[4] as usize) - { - return Ok(()); - } - - let message_length = if atyp == SocksAddressType::Ipv4 as u8 { - 10 - } else if atyp == SocksAddressType::Ipv6 as u8 { - 22 - } else { - 7 + (self.server_inbuf[4] as usize) + let message_length = match SocksAddressType::try_from(atyp)? { + SocksAddressType::DomainName => { + if self.server_inbuf.len() < 5 { + return Ok(()); + } + if self.server_inbuf.len() < 7 + (self.server_inbuf[4] as usize) { + return Ok(()); + } + 7 + (self.server_inbuf[4] as usize) + } + SocksAddressType::Ipv4 => 10, + SocksAddressType::Ipv6 => 22, }; self.server_inbuf.drain(0..message_length); @@ -277,7 +285,7 @@ impl SocksConnection { } else { SocksAddressType::Ipv6 }; - self.server_outbuf.extend(&[cmd as u8]); + self.server_outbuf.extend(&[u8::from(cmd)]); match dst_ip { IpAddr::V4(ip) => self.server_outbuf.extend(ip.octets().as_ref()), IpAddr::V6(ip) => self.server_outbuf.extend(ip.octets().as_ref()), @@ -285,7 +293,7 @@ impl SocksConnection { } DestinationHost::Hostname(host) => { self.server_outbuf - .extend(&[SocksAddressType::DomainName as u8, host.len() as u8]); + .extend(&[u8::from(SocksAddressType::DomainName), host.len() as u8]); self.server_outbuf.extend(host.as_bytes()); } }