|
|
@ -1,7 +1,7 @@ |
|
|
|
mod auth; |
|
|
|
|
|
|
|
use crate::auth::Auth; |
|
|
|
use clap::{Args, Parser}; |
|
|
|
use clap::{Args, Parser, value_parser}; |
|
|
|
use color_eyre::eyre::Result; |
|
|
|
|
|
|
|
use tokio_socks::tcp::Socks5Stream; |
|
|
@ -17,6 +17,9 @@ use hyper::server::conn::http1; |
|
|
|
use hyper::service::service_fn; |
|
|
|
use hyper::upgrade::Upgraded; |
|
|
|
use hyper::{Method, Request, Response}; |
|
|
|
use hyper::header::{HeaderValue, PROXY_AUTHENTICATE}; |
|
|
|
use base64::engine::general_purpose; |
|
|
|
use base64::Engine; |
|
|
|
|
|
|
|
use hyper_util::rt::TokioIo; |
|
|
|
use tokio::net::TcpListener; |
|
|
@ -53,6 +56,14 @@ struct Cli { |
|
|
|
/// Comma-separated list of allowed domains
|
|
|
|
#[arg(long, value_delimiter = ',')] |
|
|
|
allowed_domains: Option<Vec<String>>, |
|
|
|
|
|
|
|
/// HTTP Basic Auth credentials in the format "user:passwd"
|
|
|
|
#[arg(long)] |
|
|
|
http_basic: Option<String>, |
|
|
|
|
|
|
|
/// Disable HTTP authentication [default: 1]
|
|
|
|
#[arg(long, value_parser = value_parser ! (u8).range(0..=1), default_value_t = 1)] |
|
|
|
no_httpauth: u8, |
|
|
|
} |
|
|
|
|
|
|
|
#[tokio::main] |
|
|
@ -72,6 +83,9 @@ async fn main() -> Result<()> { |
|
|
|
let addr = SocketAddr::from((args.listen_ip, port)); |
|
|
|
let allowed_domains = args.allowed_domains; |
|
|
|
let allowed_domains = &*Box::leak(Box::new(allowed_domains)); |
|
|
|
let http_basic = args.http_basic.map(|hb| format!("Basic {}", general_purpose::STANDARD.encode(hb))); |
|
|
|
let http_basic = &*Box::leak(Box::new(http_basic)); |
|
|
|
let no_httpauth = args.no_httpauth == 1; |
|
|
|
|
|
|
|
let listener = TcpListener::bind(addr).await?; |
|
|
|
info!("Listening on http://{}", addr); |
|
|
@ -80,7 +94,7 @@ async fn main() -> Result<()> { |
|
|
|
let (stream, _) = listener.accept().await?; |
|
|
|
let io = TokioIo::new(stream); |
|
|
|
|
|
|
|
let serve_connection = service_fn(move |req| proxy(req, socks_addr, auth, allowed_domains)); |
|
|
|
let serve_connection = service_fn(move |req| proxy(req, socks_addr, auth, &http_basic, allowed_domains, no_httpauth)); |
|
|
|
|
|
|
|
tokio::task::spawn(async move { |
|
|
|
if let Err(err) = http1::Builder::new() |
|
|
@ -100,11 +114,47 @@ async fn proxy( |
|
|
|
req: Request<hyper::body::Incoming>, |
|
|
|
socks_addr: SocketAddr, |
|
|
|
auth: &'static Option<Auth>, |
|
|
|
http_basic: &Option<String>, |
|
|
|
allowed_domains: &Option<Vec<String>>, |
|
|
|
no_httpauth: bool, |
|
|
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { |
|
|
|
let uri = req.uri(); |
|
|
|
let mut http_authed = false; |
|
|
|
let hm = req.headers(); |
|
|
|
|
|
|
|
if no_httpauth { |
|
|
|
http_authed = true; |
|
|
|
} else if hm.contains_key("proxy-authorization") { |
|
|
|
let config_auth = match http_basic { |
|
|
|
Some(value) => value.clone(), |
|
|
|
None => String::new(), |
|
|
|
}; |
|
|
|
let http_auth = hm.get("proxy-authorization").unwrap(); |
|
|
|
if http_auth == &HeaderValue::from_str(&config_auth).unwrap() { |
|
|
|
http_authed = true; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// When the request does not contain a Proxy-Authorization header,
|
|
|
|
// send a 407 response code and a Proxy-Authenticate header
|
|
|
|
let mut response = Response::new(full("Proxy authentication required")); |
|
|
|
*response.status_mut() = http::StatusCode::PROXY_AUTHENTICATION_REQUIRED; |
|
|
|
response.headers_mut().insert( |
|
|
|
PROXY_AUTHENTICATE, |
|
|
|
HeaderValue::from_static("Basic realm=\"proxy\""), |
|
|
|
); |
|
|
|
return Ok(response); |
|
|
|
} |
|
|
|
|
|
|
|
if !http_authed { |
|
|
|
warn!("Failed to authenticate: {:?}", hm); |
|
|
|
let mut resp = Response::new(full( |
|
|
|
"Authorization failed, you are not allowed through the proxy.", |
|
|
|
)); |
|
|
|
*resp.status_mut() = http::StatusCode::FORBIDDEN; |
|
|
|
return Ok(resp); |
|
|
|
} |
|
|
|
|
|
|
|
let method = req.method(); |
|
|
|
debug!("Proxying request: {} {}", method, uri); |
|
|
|
debug!("Proxying request: {} {}", method, req.uri()); |
|
|
|
if let (Some(allowed_domains), Some(request_domain)) = (allowed_domains, req.uri().host()) { |
|
|
|
let domain = request_domain.to_owned(); |
|
|
|
if !allowed_domains.contains(&domain) { |
|
|
|