diff --git a/src/bin/main.rs b/src/bin/main.rs index 0c82309..d7174d9 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -37,6 +37,7 @@ async fn main_async(args: Args) -> Result<(), BoxError> { let shutdown_token = tokio_util::sync::CancellationToken::new(); let main_loop_handle = tokio::spawn({ + let args = args.clone(); let shutdown_token = shutdown_token.clone(); async move { #[cfg(target_os = "linux")] @@ -44,7 +45,7 @@ async fn main_async(args: Args) -> Result<(), BoxError> { if let Err(err) = namespace_proxy_main(args, shutdown_token).await { log::error!("namespace proxy error: {}", err); } - return; + return Ok(0); } unsafe extern "C" fn traffic_cb(status: *const tun2proxy::TrafficStatus, _: *mut std::ffi::c_void) { @@ -53,9 +54,11 @@ async fn main_async(args: Args) -> Result<(), BoxError> { } unsafe { tun2proxy::tun2proxy_set_traffic_status_callback(1, Some(traffic_cb), std::ptr::null_mut()) }; - if let Err(err) = tun2proxy::general_run_async(args, tun::DEFAULT_MTU, cfg!(target_os = "macos"), shutdown_token).await { - log::error!("main loop error: {}", err); + let ret = tun2proxy::general_run_async(args, tun::DEFAULT_MTU, cfg!(target_os = "macos"), shutdown_token).await; + if let Err(err) = &ret { + log::error!("main loop error: {err}"); } + ret } }); @@ -68,13 +71,19 @@ async fn main_async(args: Args) -> Result<(), BoxError> { }) .await; - main_loop_handle.await?; + let tasks = main_loop_handle.await??; if ctrlc_fired.load(std::sync::atomic::Ordering::SeqCst) { log::info!("Ctrl-C fired, waiting the handler to finish..."); ctrlc_handel.await.map_err(|err| err.to_string())?; } + if args.exit_on_fatal_error && tasks >= args.max_sessions { + // Because `main_async` function perhaps stuck in `await` state, so we need to exit the process forcefully + log::info!("Internal fatal error, max sessions reached ({tasks}/{})", args.max_sessions); + std::process::exit(-1); + } + Ok(()) } diff --git a/src/general_api.rs b/src/general_api.rs index 2838271..9e7dd3c 100644 --- a/src/general_api.rs +++ b/src/general_api.rs @@ -120,11 +120,18 @@ pub fn general_run_for_api(args: Args, tun_mtu: u16, packet_information: bool) - return -3; }; match rt.block_on(async move { - if let Err(err) = general_run_async(args, tun_mtu, packet_information, shutdown_token).await { - log::error!("main loop error: {}", err); - return Err(err); + let ret = general_run_async(args.clone(), tun_mtu, packet_information, shutdown_token).await; + match &ret { + Ok(sessions) => { + if args.exit_on_fatal_error && *sessions >= args.max_sessions { + log::error!("Forced exit due to max sessions reached ({sessions}/{})", args.max_sessions); + std::process::exit(-1); + } + log::debug!("tun2proxy exited normally, current sessions: {sessions}"); + } + Err(err) => log::error!("main loop error: {err}"), } - Ok(()) + ret }) { Ok(_) => 0, Err(e) => { @@ -140,7 +147,7 @@ pub async fn general_run_async( tun_mtu: u16, _packet_information: bool, shutdown_token: tokio_util::sync::CancellationToken, -) -> std::io::Result<()> { +) -> std::io::Result { let mut tun_config = tun::Configuration::default(); #[cfg(any(target_os = "linux", target_os = "windows", target_os = "macos"))] diff --git a/src/lib.rs b/src/lib.rs index c62e08b..06e3697 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,7 @@ pub mod win_svc; const DNS_PORT: u16 = 53; -static TASK_COUNT: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0); +static TASK_COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); use std::sync::atomic::Ordering::Relaxed; #[allow(unused)] @@ -154,7 +154,9 @@ async fn create_udp_stream(socket_queue: &Option>, peer: Socket /// * `mtu` - The MTU of the network device /// * `args` - The arguments to use /// * `shutdown_token` - The token to exit the server -pub async fn run(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result<()> +/// # Returns +/// * The number of sessions while exiting +pub async fn run(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result where D: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -265,10 +267,10 @@ where ip_stack_stream? } }; - let max_sessions = args.max_sessions as u64; + let max_sessions = args.max_sessions; match ip_stack_stream { IpStackStream::Tcp(tcp) => { - if TASK_COUNT.load(Relaxed) > max_sessions { + if TASK_COUNT.load(Relaxed) >= max_sessions { if args.exit_on_fatal_error { log::info!("Too many sessions that over {max_sessions}, exiting..."); break; @@ -276,7 +278,7 @@ where log::warn!("Too many sessions that over {max_sessions}, dropping new session"); continue; } - log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed) + 1); + log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed).saturating_add(1)); let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp); let domain_name = if let Some(virtual_dns) = &virtual_dns { let mut virtual_dns = virtual_dns.lock().await; @@ -291,11 +293,11 @@ where if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await { log::error!("{} error \"{}\"", info, err); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); }); } IpStackStream::Udp(udp) => { - if TASK_COUNT.load(Relaxed) > max_sessions { + if TASK_COUNT.load(Relaxed) >= max_sessions { if args.exit_on_fatal_error { log::info!("Too many sessions that over {max_sessions}, exiting..."); break; @@ -303,7 +305,7 @@ where log::warn!("Too many sessions that over {max_sessions}, dropping new session"); continue; } - log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed) + 1); + log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed).saturating_add(1)); let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp); if info.dst.port() == DNS_PORT { if is_private_ip(info.dst.ip()) { @@ -317,7 +319,7 @@ where if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await { log::error!("{} error \"{}\"", info, err); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); }); continue; } @@ -328,7 +330,7 @@ where log::error!("{} error \"{}\"", info, err); } } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); }); continue; } @@ -359,7 +361,7 @@ where if let Err(e) = handle_udp_gateway_session(udp, udpgw, &dst_addr, proxy_handler, queue, ipv6_enabled).await { log::info!("Ending {} with \"{}\"", info, e); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); }); continue; } @@ -371,7 +373,7 @@ where if let Err(err) = handle_udp_associate_session(udp, ty, proxy_handler, socket_queue, ipv6_enabled).await { log::info!("Ending {} with \"{}\"", info, err); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed) - 1); + log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); }); } Err(e) => { @@ -390,7 +392,7 @@ where } } } - Ok(()) + Ok(TASK_COUNT.load(Relaxed)) } async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc>) -> crate::Result<()> { diff --git a/src/win_svc.rs b/src/win_svc.rs index a4090e3..5ee416e 100644 --- a/src/win_svc.rs +++ b/src/win_svc.rs @@ -78,8 +78,16 @@ fn run_service(_arguments: Vec) -> Result<(), crate::BoxErro } unsafe { crate::tun2proxy_set_traffic_status_callback(1, Some(traffic_cb), std::ptr::null_mut()) }; - if let Err(err) = crate::general_run_async(args, tun::DEFAULT_MTU, false, shutdown_token).await { - log::error!("main loop error: {}", err); + let ret = crate::general_run_async(args.clone(), tun::DEFAULT_MTU, false, shutdown_token).await; + match &ret { + Ok(sessions) => { + if args.exit_on_fatal_error && *sessions >= args.max_sessions { + log::error!("Forced exit due to max sessions reached ({sessions}/{})", args.max_sessions); + std::process::exit(-1); + } + log::debug!("tun2proxy exited normally, current sessions: {sessions}"); + } + Err(err) => log::error!("main loop error: {err}"), } Ok::<(), crate::Error>(()) })?;