diff --git a/src/main.rs b/src/main.rs index 964c186..8dc0bb7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,10 @@ struct Cli { /// Socks5 password #[clap(short = 'P', long)] password: Option, + + /// Comma-separated list of allowed domains + #[clap(long)] + allowed_domains: Option, } #[tokio::main] @@ -76,9 +80,14 @@ async fn main() -> Result<()> { }; let client: Client> = hyper::Client::builder().build(connector); let client = &*Box::leak(Box::new(client)); + let allowed_domains = match args.allowed_domains { + Some(domains) => Some(domains.split(',').map(|d| d.trim().to_owned()).collect()), + None => None, + }; + let allowed_domains = &*Box::leak(Box::new(allowed_domains)); let make_service = make_service_fn(move |_| async move { Ok::<_, Infallible>(service_fn(move |req| { - proxy(req, socks_address, auth, client) + proxy(req, socks_address, auth, client, allowed_domains.clone()) })) }); let server = Server::bind(&addr) @@ -99,8 +108,25 @@ async fn proxy( socks_address: SocketAddr, auth: &'static Option, client: &'static Client>, + allowed_domains: Option>, ) -> Result> { + let uri = req.uri(); + let method = req.method(); + let headers = req.headers(); + let req_str = format!("{} {} {:?}", method, uri, headers); + log::info!("Proxying request: {}", req_str); + if let Some(plain) = host_addr(req.uri()) { + if let Some(allowed_domains) = allowed_domains { + let req_domain = req.uri().host().unwrap_or("").to_owned(); + if !allowed_domains.iter().any(|domain| req_domain.ends_with(domain)) { + log::warn!("Access to domain {} is not allowed through the proxy.", req_domain); + let mut resp = Response::new(Body::from("Access to this domain is not allowed through the proxy.")); + *resp.status_mut() = http::StatusCode::FORBIDDEN; + return Ok(resp); + } + } + if req.method() == hyper::Method::CONNECT { tokio::task::spawn(async move { match hyper::upgrade::on(req).await {