From 88423039c643d010ef36723aa251e53932624397 Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 20 Apr 2025 19:56:36 +0800 Subject: [PATCH] make TASK_COUNT as local task_count variable --- src/lib.rs | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 06e3697..5f77672 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,9 +64,6 @@ pub mod win_svc; const DNS_PORT: u16 = 53; -static TASK_COUNT: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); -use std::sync::atomic::Ordering::Relaxed; - #[allow(unused)] #[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)] #[cfg_attr( @@ -224,11 +221,11 @@ where let socket_queue = None; use socks5_impl::protocol::Version::{V4, V5}; - let mgr = match args.proxy.proxy_type { - ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)) as Arc, - ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)) as Arc, - ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)) as Arc, - ProxyType::None => Arc::new(NoProxyManager::new()) as Arc, + let mgr: Arc = match args.proxy.proxy_type { + ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)), + ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)), + ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)), + ProxyType::None => Arc::new(NoProxyManager::new()), }; let mut ipstack_config = ipstack::IpStackConfig::default(); @@ -256,7 +253,11 @@ where client }); + let task_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + use std::sync::atomic::Ordering::Relaxed; + loop { + let task_count = task_count.clone(); let virtual_dns = virtual_dns.clone(); let ip_stack_stream = tokio::select! { _ = shutdown_token.cancelled() => { @@ -270,7 +271,7 @@ where let max_sessions = args.max_sessions; match ip_stack_stream { IpStackStream::Tcp(tcp) => { - if TASK_COUNT.load(Relaxed) >= max_sessions { + if task_count.load(Relaxed) >= max_sessions { if args.exit_on_fatal_error { log::info!("Too many sessions that over {max_sessions}, exiting..."); break; @@ -278,7 +279,7 @@ where log::warn!("Too many sessions that over {max_sessions}, dropping new session"); continue; } - log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed).saturating_add(1)); + log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1)); let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp); let domain_name = if let Some(virtual_dns) = &virtual_dns { let mut virtual_dns = virtual_dns.lock().await; @@ -293,11 +294,11 @@ where if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await { log::error!("{} error \"{}\"", info, err); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); + log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1)); }); } IpStackStream::Udp(udp) => { - if TASK_COUNT.load(Relaxed) >= max_sessions { + if task_count.load(Relaxed) >= max_sessions { if args.exit_on_fatal_error { log::info!("Too many sessions that over {max_sessions}, exiting..."); break; @@ -305,11 +306,11 @@ where log::warn!("Too many sessions that over {max_sessions}, dropping new session"); continue; } - log::trace!("Session count {}", TASK_COUNT.fetch_add(1, Relaxed).saturating_add(1)); + log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1)); let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp); if info.dst.port() == DNS_PORT { if is_private_ip(info.dst.ip()) { - info.dst.set_ip(dns_addr); + info.dst.set_ip(dns_addr); // !!! Here we change the destination address to remote DNS server!!! } if args.dns == ArgDns::OverTcp { info.protocol = IpProtocol::Tcp; @@ -319,7 +320,7 @@ where if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await { log::error!("{} error \"{}\"", info, err); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); + log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1)); }); continue; } @@ -330,7 +331,7 @@ where log::error!("{} error \"{}\"", info, err); } } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); + log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1)); }); continue; } @@ -361,7 +362,7 @@ where if let Err(e) = handle_udp_gateway_session(udp, udpgw, &dst_addr, proxy_handler, queue, ipv6_enabled).await { log::info!("Ending {} with \"{}\"", info, e); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); + log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1)); }); continue; } @@ -373,7 +374,7 @@ where if let Err(err) = handle_udp_associate_session(udp, ty, proxy_handler, socket_queue, ipv6_enabled).await { log::info!("Ending {} with \"{}\"", info, err); } - log::trace!("Session count {}", TASK_COUNT.fetch_sub(1, Relaxed).saturating_sub(1)); + log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1)); }); } Err(e) => { @@ -392,7 +393,7 @@ where } } } - Ok(TASK_COUNT.load(Relaxed)) + Ok(task_count.load(Relaxed)) } async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc>) -> crate::Result<()> {