diff --git a/Cargo.lock b/Cargo.lock index 4311915..56cb959 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -956,7 +962,7 @@ version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", @@ -1024,7 +1030,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", ] [[package]] @@ -1184,6 +1190,7 @@ dependencies = [ name = "sthp" version = "0.5.0-alpha1" dependencies = [ + "base64 0.22.1", "bytes", "clap", "color-eyre", diff --git a/Cargo.toml b/Cargo.toml index ec6b68f..13335ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,8 @@ http-body-util = "0.1.0-rc.2" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } hyper-util = { version="0.1.5",features = ["tokio"] } +base64 = "0.22.1" [dev-dependencies] socksprox = { version = "0.1" } -reqwest = { version = "0.11" } \ No newline at end of file +reqwest = { version = "0.11" } diff --git a/src/main.rs b/src/main.rs index 45efd45..2a3ebbf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,9 @@ use tracing_subscriber::EnvFilter; use std::net::{Ipv4Addr, SocketAddr}; +use base64::engine::general_purpose; +use base64::Engine; +use hyper::header::HeaderValue; use tokio::net::TcpListener; #[derive(Debug, Args)] @@ -42,6 +45,10 @@ struct Cli { /// Comma-separated list of allowed domains #[arg(long, value_delimiter = ',')] allowed_domains: Option>, + + /// HTTP Basic Auth credentials in the format "user:passwd" + #[arg(long)] + http_basic: Option, } #[tokio::main] @@ -61,18 +68,26 @@ async fn main() -> Result<()> { let addr = SocketAddr::from((args.listen_ip, port)); let allowed_domains = args.allowed_domains; let allowed_domains = &*Box::leak(Box::new(allowed_domains)); + let http_basic = args + .http_basic + .map(|hb| format!("Basic {}", general_purpose::STANDARD.encode(hb))) + .map(|hb| HeaderValue::from_str(&hb)) + .transpose()?; + let http_basic = &*Box::leak(Box::new(http_basic)); let listener = TcpListener::bind(addr).await?; info!("Listening on http://{}", addr); loop { - let (stream, _) = listener.accept().await?; + let (stream, client_addr) = listener.accept().await?; tokio::task::spawn(async move { if let Err(e) = proxy_request( stream, + client_addr, socks_addr, auth_details.as_ref(), allowed_domains.as_ref(), + http_basic.as_ref(), ) .await { diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index b687a09..18562dd 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -3,6 +3,7 @@ use color_eyre::eyre::Result; pub mod auth; +use hyper::header::{HeaderValue, PROXY_AUTHENTICATE}; use hyper::service::service_fn; use tokio::net::TcpStream; use tokio_socks::tcp::Socks5Stream; @@ -23,10 +24,42 @@ use hyper::server::conn::http1; async fn proxy( req: Request, + client_addr: SocketAddr, socks_addr: SocketAddr, auth: Option<&'static Auth>, allowed_domains: Option<&'static Vec>, + basic_http_header: Option<&HeaderValue>, ) -> Result>> { + let mut authenticated = false; + let hm = req.headers(); + + if let Some(basic_http_header) = basic_http_header { + let Some(http_auth) = hm.get("proxy-authorization") else { + // When the request does not contain a Proxy-Authorization header, + // send a 407 response code and a Proxy-Authenticate header + let mut response = Response::new(full("Proxy Authentication Required")); + *response.status_mut() = hyper::StatusCode::PROXY_AUTHENTICATION_REQUIRED; + response.headers_mut().insert( + PROXY_AUTHENTICATE, + HeaderValue::from_static("Basic realm=\"proxy\""), + ); + return Ok(response); + }; + if http_auth == basic_http_header { + authenticated = true; + } + } else { + authenticated = true; + } + + if !authenticated { + warn!("Failed auth attempt from: {}", client_addr); + // http response code reference taken from tinyproxy + let mut resp = Response::new(full("Unauthorized")); + *resp.status_mut() = hyper::StatusCode::UNAUTHORIZED; + return Ok(resp); + } + let uri = req.uri(); let method = req.method(); debug!("Proxying request: {} {}", method, uri); @@ -146,14 +179,24 @@ async fn tunnel( pub async fn proxy_request( stream: TcpStream, + client_addr: SocketAddr, socks_addr: SocketAddr, auth_details: Option<&'static Auth>, allowed_domains: Option<&'static Vec>, + basic_http_header: Option<&'static HeaderValue>, ) -> color_eyre::Result<()> { let io = TokioIo::new(stream); - let serve_connection = - service_fn(move |req| proxy(req, socks_addr, auth_details, allowed_domains)); + let serve_connection = service_fn(move |req| { + proxy( + req, + client_addr, + socks_addr, + auth_details, + allowed_domains, + basic_http_header, + ) + }); tokio::task::spawn(async move { if let Err(err) = http1::Builder::new() diff --git a/tests/proxy.rs b/tests/proxy.rs index 61dcb31..136d8a1 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -35,7 +35,7 @@ async fn simple_test() -> Result<()> { let addr = listener.local_addr()?; let _ = tokio::task::spawn(async move { let (stream, proxy_addr) = listener.accept().await?; - proxy_request(stream, socks_proxy_addr, None, None).await?; + proxy_request(stream, proxy_addr, socks_proxy_addr, None, None, None).await?; eprintln!("new connection from: {:?}", proxy_addr); Ok::<_, color_eyre::eyre::Error>(()) });