Browse Source

feat: add allowed domain options

pull/11/head
mostafa 2 years ago
parent
commit
66092ceab5
  1. 28
      src/main.rs

28
src/main.rs

@ -33,6 +33,10 @@ struct Cli {
/// Socks5 password
#[clap(short = 'P', long)]
password: Option<String>,
/// Comma-separated list of allowed domains
#[clap(long)]
allowed_domains: Option<String>,
}
#[tokio::main]
@ -76,9 +80,14 @@ async fn main() -> Result<()> {
};
let client: Client<SocksConnector<HttpConnector>> = hyper::Client::builder().build(connector);
let client = &*Box::leak(Box::new(client));
let allowed_domains = match args.allowed_domains {
Some(domains) => Some(domains.split(',').map(|d| d.trim().to_owned()).collect()),
None => None,
};
let allowed_domains = &*Box::leak(Box::new(allowed_domains));
let make_service = make_service_fn(move |_| async move {
Ok::<_, Infallible>(service_fn(move |req| {
proxy(req, socks_address, auth, client)
proxy(req, socks_address, auth, client, allowed_domains.clone())
}))
});
let server = Server::bind(&addr)
@ -99,8 +108,25 @@ async fn proxy(
socks_address: SocketAddr,
auth: &'static Option<Auth>,
client: &'static Client<SocksConnector<HttpConnector>>,
allowed_domains: Option<Vec<String>>,
) -> Result<Response<Body>> {
let uri = req.uri();
let method = req.method();
let headers = req.headers();
let req_str = format!("{} {} {:?}", method, uri, headers);
log::info!("Proxying request: {}", req_str);
if let Some(plain) = host_addr(req.uri()) {
if let Some(allowed_domains) = allowed_domains {
let req_domain = req.uri().host().unwrap_or("").to_owned();
if !allowed_domains.iter().any(|domain| req_domain.ends_with(domain)) {
log::warn!("Access to domain {} is not allowed through the proxy.", req_domain);
let mut resp = Response::new(Body::from("Access to this domain is not allowed through the proxy."));
*resp.status_mut() = http::StatusCode::FORBIDDEN;
return Ok(resp);
}
}
if req.method() == hyper::Method::CONNECT {
tokio::task::spawn(async move {
match hyper::upgrade::on(req).await {

Loading…
Cancel
Save