Browse Source

perf(client): lock-free hot paths, parallel PoW, per-client_id throttle

- relayPool/sessionPool: atomic.Pointer copy-on-write, drop RWMutex from pick()
- DTLS read loop caches activeLocalPeer locally to skip type-assert per packet
- solvePoW parallelised across runtime.NumCPU() workers
- vkRequestMu replaced with per-client_id throttle so distinct client_ids run in parallel
- inboundChan 2000 -> 8192, periodic drop-counter logging
- listener caches addr.String() to avoid redundant atomic.Value stores
pull/151/head
samosvalishe 2 months ago
parent
commit
5c3b6a681c
  1. 222
      client/main.go

222
client/main.go

@ -22,6 +22,7 @@ import (
neturl "net/url"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
@ -465,15 +466,45 @@ func fetchCaptchaBootstrap(ctx context.Context, redirectURI string, client tlscl
func solvePoW(powInput string, difficulty int) (string, error) {
target := strings.Repeat("0", difficulty)
for nonce := 1; nonce <= 10000000; nonce++ {
data := powInput + strconv.Itoa(nonce)
hash := sha256.Sum256([]byte(data))
hexHash := hex.EncodeToString(hash[:])
if strings.HasPrefix(hexHash, target) {
return hexHash, nil
}
const maxNonce = 10000000
workers := runtime.NumCPU()
if workers < 1 {
workers = 1
}
var (
found atomic.Bool
resultCh = make(chan string, 1)
wg sync.WaitGroup
)
for w := 0; w < workers; w++ {
wg.Add(1)
go func(start int) {
defer wg.Done()
for nonce := start; nonce <= maxNonce; nonce += workers {
if found.Load() {
return
}
data := powInput + strconv.Itoa(nonce)
hash := sha256.Sum256([]byte(data))
hexHash := hex.EncodeToString(hash[:])
if strings.HasPrefix(hexHash, target) {
if found.CompareAndSwap(false, true) {
resultCh <- hexHash
}
return
}
}
}(w + 1)
}
return "", fmt.Errorf("PoW unsolved (difficulty=%d, tried 10M nonces)", difficulty)
go func() { wg.Wait(); close(resultCh) }()
if h, ok := <-resultCh; ok {
return h, nil
}
return "", fmt.Errorf("PoW unsolved (difficulty=%d, tried %dM nonces)", difficulty, maxNonce/1000000)
}
func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamID int, client tlsclient.HttpClient, profile Profile, savedProfile *SavedProfile) (string, error) {
@ -775,11 +806,32 @@ func getVkCredsCached(ctx context.Context, link string, streamID int) (string, s
return user, pass, addr, nil
}
// vkClientThrottle holds per-client_id serialisation + cooldown timestamp.
// Was previously a single global mutex which forced acquires across distinct
// client_ids onto the same queue even though VK rate-limits per client_id.
type vkClientThrottle struct {
mu sync.Mutex
lastTime time.Time
}
var (
vkRequestMu sync.Mutex
globalLastVkFetchTime time.Time
vkThrottleStore = struct {
mu sync.Mutex
m map[string]*vkClientThrottle
}{m: make(map[string]*vkClientThrottle)}
)
func getVkThrottle(clientID string) *vkClientThrottle {
vkThrottleStore.mu.Lock()
defer vkThrottleStore.mu.Unlock()
t, ok := vkThrottleStore.m[clientID]
if !ok {
t = &vkClientThrottle{}
vkThrottleStore.m[clientID] = t
}
return t
}
// vkIdentity caches the captcha-gated portion of a VK auth chain (steps 1-3:
// anonym_token + getCallPreview + getAnonymousToken). Once acquired it can be
// replayed via acquireVkTurnSlot to mint independent TURN credentials, each
@ -942,16 +994,18 @@ func vkDoRequest(ctx context.Context, client tlsclient.HttpClient, profile Profi
// (steps 1-3: get_anonym_token, calls.getCallPreview, calls.getAnonymousToken).
// The result is cached and reused across many TURN slot acquisitions.
//
// Globally serialised via vkRequestMu + 3-6s cooldown to avoid VK API bans.
// Per-client_id serialised + 3-6s cooldown to avoid VK API bans. Distinct
// client_ids run in parallel.
func acquireVkIdentity(ctx context.Context, link string, streamID int, creds VKCredentials) (*vkIdentity, error) {
vkRequestMu.Lock()
defer vkRequestMu.Unlock()
throttle := getVkThrottle(creds.ClientID)
throttle.mu.Lock()
defer throttle.mu.Unlock()
minInterval := 3*time.Second + time.Duration(rand.Intn(3000))*time.Millisecond
elapsed := time.Since(globalLastVkFetchTime)
if !globalLastVkFetchTime.IsZero() && elapsed < minInterval {
elapsed := time.Since(throttle.lastTime)
if !throttle.lastTime.IsZero() && elapsed < minInterval {
wait := minInterval - elapsed
log.Printf("[STREAM %d] [VK Auth] Throttling: waiting %v to prevent rate limit...", streamID, wait.Truncate(time.Millisecond))
log.Printf("[STREAM %d] [VK Auth] Throttling client_id=%s: waiting %v...", streamID, creds.ClientID, wait.Truncate(time.Millisecond))
select {
case <-ctx.Done():
return nil, ctx.Err()
@ -959,7 +1013,7 @@ func acquireVkIdentity(ctx context.Context, link string, streamID int, creds VKC
}
}
defer func() {
globalLastVkFetchTime = time.Now()
throttle.lastTime = time.Now()
}()
if time.Now().Unix() < globalCaptchaLockout.Load() {
@ -1655,20 +1709,31 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa
defer wg.Done()
defer dtlscancel()
buf := make([]byte, 1600)
var cachedAddr net.Addr
var cachedPtr any
for {
n, err1 := dtlsConn.Read(buf)
if err1 != nil {
return
}
// Send back to the active WG client
if peerAddr := activeLocalPeer.Load(); peerAddr != nil {
if addr, ok := peerAddr.(net.Addr); ok {
if _, err := listenConn.WriteTo(buf[:n], addr); err != nil {
log.Printf("[STREAM %d] failed to forward packet to local peer: %v", streamID, err)
}
// Send back to the active WG client. Cache addr locally — only
// re-resolve when atomic.Value pointer changes (rare).
peerAddr := activeLocalPeer.Load()
if peerAddr == nil {
continue
}
if peerAddr != cachedPtr {
if a, ok := peerAddr.(net.Addr); ok {
cachedAddr = a
cachedPtr = peerAddr
} else {
continue
}
}
if _, err := listenConn.WriteTo(buf[:n], cachedAddr); err != nil {
log.Printf("[STREAM %d] failed to forward packet to local peer: %v", streamID, err)
}
}
}()
@ -1832,28 +1897,40 @@ func dialTurn(ctx context.Context, useUDP bool, turnServerAddr string, turnServe
}
// relayPool is a concurrent ring of live relay PacketConns. Reads (pick) are
// non-blocking and lock-free on the hot path; mutation (add) is rare.
// fully lock-free via atomic.Pointer to a snapshot slice (copy-on-write).
// add is rare and pays the alloc cost.
type relayPool struct {
mu sync.RWMutex
relays []net.PacketConn
relays atomic.Pointer[[]net.PacketConn]
addMu sync.Mutex
counter atomic.Uint64
}
func (p *relayPool) add(r net.PacketConn) {
p.mu.Lock()
p.relays = append(p.relays, r)
p.mu.Unlock()
p.addMu.Lock()
defer p.addMu.Unlock()
cur := p.relays.Load()
var next []net.PacketConn
if cur != nil {
next = make([]net.PacketConn, len(*cur)+1)
copy(next, *cur)
next[len(*cur)] = r
} else {
next = []net.PacketConn{r}
}
p.relays.Store(&next)
}
func (p *relayPool) pick() net.PacketConn {
p.mu.RLock()
defer p.mu.RUnlock()
n := len(p.relays)
cur := p.relays.Load()
if cur == nil {
return nil
}
n := len(*cur)
if n == 0 {
return nil
}
idx := int(p.counter.Add(1)-1) % n
return p.relays[idx]
return (*cur)[idx]
}
func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UDPAddr, conn2 net.PacketConn, streamID int, c chan<- error) {
@ -2217,9 +2294,24 @@ func main() {
}
// Shared Worker Pool Queue for Aggregation
inboundChan := make(chan *UDPPacket, 2000)
inboundChan := make(chan *UDPPacket, 8192)
var droppedPkts atomic.Uint64
go func() {
dropTicker := time.NewTicker(5 * time.Second)
defer dropTicker.Stop()
var lastDropped uint64
for range dropTicker.C {
cur := droppedPkts.Load()
if cur != lastDropped {
log.Printf("[inbound] dropped %d pkts since start (delta=%d) — queue saturated", cur, cur-lastDropped)
lastDropped = cur
}
}
}()
go func() {
var lastAddrStr string
for {
pktIface := packetPool.Get()
pkt, ok := pktIface.(*UDPPacket)
@ -2232,16 +2324,12 @@ func main() {
return
}
// Save the local WireGuard peer address
current := activeLocalPeer.Load()
if current == nil {
activeLocalPeer.Store(addr)
} else if addrStr, ok := current.(net.Addr); ok {
if addrStr.String() != addr.String() {
activeLocalPeer.Store(addr)
}
} else {
// Save local WireGuard peer addr; cache string to avoid repeated
// type-assert + String() in hot path.
s := addr.String()
if s != lastAddrStr {
activeLocalPeer.Store(addr)
lastAddrStr = s
}
pkt.N = nRead
@ -2249,7 +2337,7 @@ func main() {
select {
case inboundChan <- pkt:
default:
// Drop the packet only if the global queue is completely full
droppedPkts.Add(1)
packetPool.Put(pkt)
}
}
@ -2297,45 +2385,57 @@ func main() {
wg1.Wait()
}
// sessionPool manages a pool of smux sessions for round-robin TCP distribution.
// sessionPool manages smux sessions for round-robin TCP distribution.
// Lock-free reads via atomic.Pointer copy-on-write snapshot.
type sessionPool struct {
mu sync.RWMutex
sessions []*smux.Session
sessions atomic.Pointer[[]*smux.Session]
mu sync.Mutex
counter atomic.Uint64
}
func (p *sessionPool) snapshot() []*smux.Session {
cur := p.sessions.Load()
if cur == nil {
return nil
}
return *cur
}
func (p *sessionPool) add(s *smux.Session) {
p.mu.Lock()
p.sessions = append(p.sessions, s)
p.mu.Unlock()
defer p.mu.Unlock()
cur := p.snapshot()
next := make([]*smux.Session, len(cur)+1)
copy(next, cur)
next[len(cur)] = s
p.sessions.Store(&next)
}
func (p *sessionPool) remove(s *smux.Session) {
p.mu.Lock()
for i, sess := range p.sessions {
if sess == s {
p.sessions = append(p.sessions[:i], p.sessions[i+1:]...)
break
defer p.mu.Unlock()
cur := p.snapshot()
next := make([]*smux.Session, 0, len(cur))
for _, sess := range cur {
if sess != s {
next = append(next, sess)
}
}
p.mu.Unlock()
p.sessions.Store(&next)
}
func (p *sessionPool) pick() *smux.Session {
p.mu.RLock()
defer p.mu.RUnlock()
n := len(p.sessions)
cur := p.snapshot()
n := len(cur)
if n == 0 {
return nil
}
idx := p.counter.Add(1) % uint64(n)
return p.sessions[idx]
return cur[idx]
}
func (p *sessionPool) count() int {
p.mu.RLock()
defer p.mu.RUnlock()
return len(p.sessions)
return len(p.snapshot())
}
// runVLESSMode implements TCP forwarding with round-robin across N TURN sessions.

Loading…
Cancel
Save