diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6b7ea04 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod proxy; +pub use proxy::proxy_request; diff --git a/src/main.rs b/src/main.rs index 576111c..b410220 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,29 +1,18 @@ -mod auth; - -use crate::auth::Auth; use clap::{Args, Parser}; use color_eyre::eyre::Result; -use tokio_socks::tcp::Socks5Stream; -use tracing::{debug, info, warn}; +use sthp::proxy::auth::Auth; +use sthp::proxy_request; +use tracing::info; use tracing_subscriber::EnvFilter; use std::net::{Ipv4Addr, SocketAddr}; -use bytes::Bytes; -use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; -use hyper::client::conn::http1::Builder; -use hyper::server::conn::http1; -use hyper::service::service_fn; -use hyper::upgrade::Upgraded; -use hyper::{Method, Request, Response}; - -use hyper_util::rt::TokioIo; use tokio::net::TcpListener; #[derive(Debug, Args)] #[group()] -struct Auths { +struct AuthParams { /// Socks5 username #[arg(short = 'u', long, required = false)] username: String, @@ -44,7 +33,7 @@ struct Cli { listen_ip: Ipv4Addr, #[command(flatten)] - auth: Option, + auth: Option, /// Socks5 proxy address #[arg(short, long, default_value = "127.0.0.1:1080")] @@ -65,10 +54,10 @@ async fn main() -> Result<()> { let socks_addr = args.socks_address; let port = args.port; - let auth = args + let auth_details = args .auth .map(|auth| Auth::new(auth.username, auth.password)); - let auth = &*Box::leak(Box::new(auth)); + let auth_details = &*Box::leak(Box::new(auth_details)); let addr = SocketAddr::from((args.listen_ip, port)); let allowed_domains = args.allowed_domains; let allowed_domains = &*Box::leak(Box::new(allowed_domains)); @@ -78,143 +67,8 @@ async fn main() -> Result<()> { loop { let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); - - let serve_connection = service_fn(move |req| proxy(req, socks_addr, auth, allowed_domains)); - tokio::task::spawn(async move { - if let Err(err) = http1::Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .serve_connection(io, serve_connection) - .with_upgrades() - .await - { - warn!("Failed to serve connection: {:?}", err); - } + proxy_request(stream, socks_addr, auth_details, allowed_domains).await }); } } - -async fn proxy( - req: Request, - socks_addr: SocketAddr, - auth: &'static Option, - allowed_domains: &Option>, -) -> Result>, hyper::Error> { - let uri = req.uri(); - let method = req.method(); - debug!("Proxying request: {} {}", method, uri); - if let (Some(allowed_domains), Some(request_domain)) = (allowed_domains, req.uri().host()) { - let domain = request_domain.to_owned(); - if !allowed_domains.contains(&domain) { - warn!( - "Access to domain {} is not allowed through the proxy.", - domain - ); - let mut resp = Response::new(full( - "Access to this domain is not allowed through the proxy.", - )); - *resp.status_mut() = http::StatusCode::FORBIDDEN; - return Ok(resp); - } - } - - if Method::CONNECT == req.method() { - if let Some(addr) = host_addr(req.uri()) { - tokio::task::spawn(async move { - match hyper::upgrade::on(req).await { - Ok(upgraded) => { - if let Err(e) = tunnel(upgraded, addr, socks_addr, auth).await { - warn!("server io error: {}", e); - }; - } - Err(e) => warn!("upgrade error: {}", e), - } - }); - - Ok(Response::new(empty())) - } else { - warn!("CONNECT host is not socket addr: {:?}", req.uri()); - let mut resp = Response::new(full("CONNECT must be to a socket address")); - *resp.status_mut() = http::StatusCode::BAD_REQUEST; - - Ok(resp) - } - } else { - let host = req.uri().host().expect("uri has no host"); - let port = req.uri().port_u16().unwrap_or(80); - let addr = format!("{}:{}", host, port); - - let stream = match auth { - Some(auth) => Socks5Stream::connect_with_password( - socks_addr, - addr, - &auth.username, - &auth.password, - ) - .await - .unwrap(), - None => Socks5Stream::connect(socks_addr, addr).await.unwrap(), - }; - let io = TokioIo::new(stream); - - let (mut sender, conn) = Builder::new() - .preserve_header_case(true) - .title_case_headers(true) - .handshake(io) - .await?; - tokio::task::spawn(async move { - if let Err(err) = conn.await { - warn!("Connection failed: {:?}", err); - } - }); - - let resp = sender.send_request(req).await?; - Ok(resp.map(|b| b.boxed())) - } -} - -fn host_addr(uri: &http::Uri) -> Option { - uri.authority().map(|auth| auth.to_string()) -} - -fn empty() -> BoxBody { - Empty::::new() - .map_err(|never| match never {}) - .boxed() -} - -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} - -async fn tunnel( - upgraded: Upgraded, - addr: String, - socks_addr: SocketAddr, - auth: &Option, -) -> Result<()> { - let mut stream = match auth { - Some(auth) => { - Socks5Stream::connect_with_password(socks_addr, addr, &auth.username, &auth.password) - .await? - } - None => Socks5Stream::connect(socks_addr, addr).await?, - }; - - let mut upgraded = TokioIo::new(upgraded); - - // Proxying data - let (from_client, from_server) = - tokio::io::copy_bidirectional(&mut upgraded, &mut stream).await?; - - // Print message when done - debug!( - "client wrote {} bytes and received {} bytes", - from_client, from_server - ); - Ok(()) -} diff --git a/src/auth.rs b/src/proxy/auth.rs similarity index 100% rename from src/auth.rs rename to src/proxy/auth.rs diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs new file mode 100644 index 0000000..be3b131 --- /dev/null +++ b/src/proxy/mod.rs @@ -0,0 +1,169 @@ +use auth::Auth; +use color_eyre::eyre::Result; + +pub mod auth; + +use hyper::service::service_fn; +use tokio::net::TcpStream; +use tokio_socks::tcp::Socks5Stream; +use tracing::{debug, warn}; + +use std::net::SocketAddr; + +use bytes::Bytes; +use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; + +use hyper::upgrade::Upgraded; +use hyper::{Method, Request, Response}; + +use hyper_util::rt::TokioIo; + +use hyper::client::conn::http1::Builder; +use hyper::server::conn::http1; + +async fn proxy( + req: Request, + socks_addr: SocketAddr, + auth: &'static Option, + allowed_domains: &Option>, +) -> Result>, hyper::Error> { + let uri = req.uri(); + let method = req.method(); + debug!("Proxying request: {} {}", method, uri); + if let (Some(allowed_domains), Some(request_domain)) = (allowed_domains, req.uri().host()) { + let domain = request_domain.to_owned(); + if !allowed_domains.contains(&domain) { + warn!( + "Access to domain {} is not allowed through the proxy.", + domain + ); + let mut resp = Response::new(full( + "Access to this domain is not allowed through the proxy.", + )); + *resp.status_mut() = http::StatusCode::FORBIDDEN; + return Ok(resp); + } + } + + if Method::CONNECT == req.method() { + if let Some(addr) = host_addr(req.uri()) { + tokio::task::spawn(async move { + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + if let Err(e) = tunnel(upgraded, addr, socks_addr, auth).await { + warn!("server io error: {}", e); + }; + } + Err(e) => warn!("upgrade error: {}", e), + } + }); + + Ok(Response::new(empty())) + } else { + warn!("CONNECT host is not socket addr: {:?}", req.uri()); + let mut resp = Response::new(full("CONNECT must be to a socket address")); + *resp.status_mut() = http::StatusCode::BAD_REQUEST; + + Ok(resp) + } + } else { + let host = req.uri().host().expect("uri has no host"); + let port = req.uri().port_u16().unwrap_or(80); + let addr = format!("{}:{}", host, port); + + let stream = match auth { + Some(auth) => Socks5Stream::connect_with_password( + socks_addr, + addr, + &auth.username, + &auth.password, + ) + .await + .unwrap(), + None => Socks5Stream::connect(socks_addr, addr).await.unwrap(), + }; + let io = TokioIo::new(stream); + + let (mut sender, conn) = Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .handshake(io) + .await?; + tokio::task::spawn(async move { + if let Err(err) = conn.await { + warn!("Connection failed: {:?}", err); + } + }); + + let resp = sender.send_request(req).await?; + Ok(resp.map(|b| b.boxed())) + } +} + +fn host_addr(uri: &http::Uri) -> Option { + uri.authority().map(|auth| auth.to_string()) +} + +fn empty() -> BoxBody { + Empty::::new() + .map_err(|never| match never {}) + .boxed() +} + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +async fn tunnel( + upgraded: Upgraded, + addr: String, + socks_addr: SocketAddr, + auth: &Option, +) -> Result<()> { + let mut stream = match auth { + Some(auth) => { + Socks5Stream::connect_with_password(socks_addr, addr, &auth.username, &auth.password) + .await? + } + None => Socks5Stream::connect(socks_addr, addr).await?, + }; + + let mut upgraded = TokioIo::new(upgraded); + + // Proxying data + let (from_client, from_server) = + tokio::io::copy_bidirectional(&mut upgraded, &mut stream).await?; + + // Print message when done + debug!( + "client wrote {} bytes and received {} bytes", + from_client, from_server + ); + Ok(()) +} + +pub async fn proxy_request( + stream: TcpStream, + socks_addr: SocketAddr, + auth_details: &'static Option, + allowed_domains: &'static Option>, +) { + let io = TokioIo::new(stream); + + let serve_connection = + service_fn(move |req| proxy(req, socks_addr, auth_details, allowed_domains)); + + tokio::task::spawn(async move { + if let Err(err) = http1::Builder::new() + .preserve_header_case(true) + .title_case_headers(true) + .serve_connection(io, serve_connection) + .with_upgrades() + .await + { + warn!("Failed to serve connection: {:?}", err); + } + }); +}