Browse Source

perf(client): lock-free hot paths, parallel PoW, per-client_id throttle

- relayPool/sessionPool: atomic.Pointer copy-on-write, drop RWMutex from pick()
- DTLS read loop caches activeLocalPeer locally to skip type-assert per packet
- solvePoW parallelised across runtime.NumCPU() workers
- vkRequestMu replaced with per-client_id throttle so distinct client_ids run in parallel
- inboundChan 2000 -> 8192, periodic drop-counter logging
- listener caches addr.String() to avoid redundant atomic.Value stores
pull/151/head
samosvalishe 2 months ago
parent
commit
5c3b6a681c
  1. 222
      client/main.go

222
client/main.go

@ -22,6 +22,7 @@ import (
neturl "net/url" neturl "net/url"
"os" "os"
"os/signal" "os/signal"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -465,15 +466,45 @@ func fetchCaptchaBootstrap(ctx context.Context, redirectURI string, client tlscl
func solvePoW(powInput string, difficulty int) (string, error) { func solvePoW(powInput string, difficulty int) (string, error) {
target := strings.Repeat("0", difficulty) target := strings.Repeat("0", difficulty)
for nonce := 1; nonce <= 10000000; nonce++ { const maxNonce = 10000000
data := powInput + strconv.Itoa(nonce) workers := runtime.NumCPU()
hash := sha256.Sum256([]byte(data)) if workers < 1 {
hexHash := hex.EncodeToString(hash[:]) workers = 1
if strings.HasPrefix(hexHash, target) { }
return hexHash, nil
} var (
found atomic.Bool
resultCh = make(chan string, 1)
wg sync.WaitGroup
)
for w := 0; w < workers; w++ {
wg.Add(1)
go func(start int) {
defer wg.Done()
for nonce := start; nonce <= maxNonce; nonce += workers {
if found.Load() {
return
}
data := powInput + strconv.Itoa(nonce)
hash := sha256.Sum256([]byte(data))
hexHash := hex.EncodeToString(hash[:])
if strings.HasPrefix(hexHash, target) {
if found.CompareAndSwap(false, true) {
resultCh <- hexHash
}
return
}
}
}(w + 1)
} }
return "", fmt.Errorf("PoW unsolved (difficulty=%d, tried 10M nonces)", difficulty)
go func() { wg.Wait(); close(resultCh) }()
if h, ok := <-resultCh; ok {
return h, nil
}
return "", fmt.Errorf("PoW unsolved (difficulty=%d, tried %dM nonces)", difficulty, maxNonce/1000000)
} }
func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamID int, client tlsclient.HttpClient, profile Profile, savedProfile *SavedProfile) (string, error) { func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamID int, client tlsclient.HttpClient, profile Profile, savedProfile *SavedProfile) (string, error) {
@ -775,11 +806,32 @@ func getVkCredsCached(ctx context.Context, link string, streamID int) (string, s
return user, pass, addr, nil return user, pass, addr, nil
} }
// vkClientThrottle holds per-client_id serialisation + cooldown timestamp.
// Was previously a single global mutex which forced acquires across distinct
// client_ids onto the same queue even though VK rate-limits per client_id.
type vkClientThrottle struct {
mu sync.Mutex
lastTime time.Time
}
var ( var (
vkRequestMu sync.Mutex vkThrottleStore = struct {
globalLastVkFetchTime time.Time mu sync.Mutex
m map[string]*vkClientThrottle
}{m: make(map[string]*vkClientThrottle)}
) )
func getVkThrottle(clientID string) *vkClientThrottle {
vkThrottleStore.mu.Lock()
defer vkThrottleStore.mu.Unlock()
t, ok := vkThrottleStore.m[clientID]
if !ok {
t = &vkClientThrottle{}
vkThrottleStore.m[clientID] = t
}
return t
}
// vkIdentity caches the captcha-gated portion of a VK auth chain (steps 1-3: // vkIdentity caches the captcha-gated portion of a VK auth chain (steps 1-3:
// anonym_token + getCallPreview + getAnonymousToken). Once acquired it can be // anonym_token + getCallPreview + getAnonymousToken). Once acquired it can be
// replayed via acquireVkTurnSlot to mint independent TURN credentials, each // replayed via acquireVkTurnSlot to mint independent TURN credentials, each
@ -942,16 +994,18 @@ func vkDoRequest(ctx context.Context, client tlsclient.HttpClient, profile Profi
// (steps 1-3: get_anonym_token, calls.getCallPreview, calls.getAnonymousToken). // (steps 1-3: get_anonym_token, calls.getCallPreview, calls.getAnonymousToken).
// The result is cached and reused across many TURN slot acquisitions. // The result is cached and reused across many TURN slot acquisitions.
// //
// Globally serialised via vkRequestMu + 3-6s cooldown to avoid VK API bans. // Per-client_id serialised + 3-6s cooldown to avoid VK API bans. Distinct
// client_ids run in parallel.
func acquireVkIdentity(ctx context.Context, link string, streamID int, creds VKCredentials) (*vkIdentity, error) { func acquireVkIdentity(ctx context.Context, link string, streamID int, creds VKCredentials) (*vkIdentity, error) {
vkRequestMu.Lock() throttle := getVkThrottle(creds.ClientID)
defer vkRequestMu.Unlock() throttle.mu.Lock()
defer throttle.mu.Unlock()
minInterval := 3*time.Second + time.Duration(rand.Intn(3000))*time.Millisecond minInterval := 3*time.Second + time.Duration(rand.Intn(3000))*time.Millisecond
elapsed := time.Since(globalLastVkFetchTime) elapsed := time.Since(throttle.lastTime)
if !globalLastVkFetchTime.IsZero() && elapsed < minInterval { if !throttle.lastTime.IsZero() && elapsed < minInterval {
wait := minInterval - elapsed wait := minInterval - elapsed
log.Printf("[STREAM %d] [VK Auth] Throttling: waiting %v to prevent rate limit...", streamID, wait.Truncate(time.Millisecond)) log.Printf("[STREAM %d] [VK Auth] Throttling client_id=%s: waiting %v...", streamID, creds.ClientID, wait.Truncate(time.Millisecond))
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
@ -959,7 +1013,7 @@ func acquireVkIdentity(ctx context.Context, link string, streamID int, creds VKC
} }
} }
defer func() { defer func() {
globalLastVkFetchTime = time.Now() throttle.lastTime = time.Now()
}() }()
if time.Now().Unix() < globalCaptchaLockout.Load() { if time.Now().Unix() < globalCaptchaLockout.Load() {
@ -1655,20 +1709,31 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa
defer wg.Done() defer wg.Done()
defer dtlscancel() defer dtlscancel()
buf := make([]byte, 1600) buf := make([]byte, 1600)
var cachedAddr net.Addr
var cachedPtr any
for { for {
n, err1 := dtlsConn.Read(buf) n, err1 := dtlsConn.Read(buf)
if err1 != nil { if err1 != nil {
return return
} }
// Send back to the active WG client // Send back to the active WG client. Cache addr locally — only
if peerAddr := activeLocalPeer.Load(); peerAddr != nil { // re-resolve when atomic.Value pointer changes (rare).
if addr, ok := peerAddr.(net.Addr); ok { peerAddr := activeLocalPeer.Load()
if _, err := listenConn.WriteTo(buf[:n], addr); err != nil { if peerAddr == nil {
log.Printf("[STREAM %d] failed to forward packet to local peer: %v", streamID, err) continue
} }
if peerAddr != cachedPtr {
if a, ok := peerAddr.(net.Addr); ok {
cachedAddr = a
cachedPtr = peerAddr
} else {
continue
} }
} }
if _, err := listenConn.WriteTo(buf[:n], cachedAddr); err != nil {
log.Printf("[STREAM %d] failed to forward packet to local peer: %v", streamID, err)
}
} }
}() }()
@ -1832,28 +1897,40 @@ func dialTurn(ctx context.Context, useUDP bool, turnServerAddr string, turnServe
} }
// relayPool is a concurrent ring of live relay PacketConns. Reads (pick) are // relayPool is a concurrent ring of live relay PacketConns. Reads (pick) are
// non-blocking and lock-free on the hot path; mutation (add) is rare. // fully lock-free via atomic.Pointer to a snapshot slice (copy-on-write).
// add is rare and pays the alloc cost.
type relayPool struct { type relayPool struct {
mu sync.RWMutex relays atomic.Pointer[[]net.PacketConn]
relays []net.PacketConn addMu sync.Mutex
counter atomic.Uint64 counter atomic.Uint64
} }
func (p *relayPool) add(r net.PacketConn) { func (p *relayPool) add(r net.PacketConn) {
p.mu.Lock() p.addMu.Lock()
p.relays = append(p.relays, r) defer p.addMu.Unlock()
p.mu.Unlock() cur := p.relays.Load()
var next []net.PacketConn
if cur != nil {
next = make([]net.PacketConn, len(*cur)+1)
copy(next, *cur)
next[len(*cur)] = r
} else {
next = []net.PacketConn{r}
}
p.relays.Store(&next)
} }
func (p *relayPool) pick() net.PacketConn { func (p *relayPool) pick() net.PacketConn {
p.mu.RLock() cur := p.relays.Load()
defer p.mu.RUnlock() if cur == nil {
n := len(p.relays) return nil
}
n := len(*cur)
if n == 0 { if n == 0 {
return nil return nil
} }
idx := int(p.counter.Add(1)-1) % n idx := int(p.counter.Add(1)-1) % n
return p.relays[idx] return (*cur)[idx]
} }
func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UDPAddr, conn2 net.PacketConn, streamID int, c chan<- error) { func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UDPAddr, conn2 net.PacketConn, streamID int, c chan<- error) {
@ -2217,9 +2294,24 @@ func main() {
} }
// Shared Worker Pool Queue for Aggregation // Shared Worker Pool Queue for Aggregation
inboundChan := make(chan *UDPPacket, 2000) inboundChan := make(chan *UDPPacket, 8192)
var droppedPkts atomic.Uint64
go func() { go func() {
dropTicker := time.NewTicker(5 * time.Second)
defer dropTicker.Stop()
var lastDropped uint64
for range dropTicker.C {
cur := droppedPkts.Load()
if cur != lastDropped {
log.Printf("[inbound] dropped %d pkts since start (delta=%d) — queue saturated", cur, cur-lastDropped)
lastDropped = cur
}
}
}()
go func() {
var lastAddrStr string
for { for {
pktIface := packetPool.Get() pktIface := packetPool.Get()
pkt, ok := pktIface.(*UDPPacket) pkt, ok := pktIface.(*UDPPacket)
@ -2232,16 +2324,12 @@ func main() {
return return
} }
// Save the local WireGuard peer address // Save local WireGuard peer addr; cache string to avoid repeated
current := activeLocalPeer.Load() // type-assert + String() in hot path.
if current == nil { s := addr.String()
activeLocalPeer.Store(addr) if s != lastAddrStr {
} else if addrStr, ok := current.(net.Addr); ok {
if addrStr.String() != addr.String() {
activeLocalPeer.Store(addr)
}
} else {
activeLocalPeer.Store(addr) activeLocalPeer.Store(addr)
lastAddrStr = s
} }
pkt.N = nRead pkt.N = nRead
@ -2249,7 +2337,7 @@ func main() {
select { select {
case inboundChan <- pkt: case inboundChan <- pkt:
default: default:
// Drop the packet only if the global queue is completely full droppedPkts.Add(1)
packetPool.Put(pkt) packetPool.Put(pkt)
} }
} }
@ -2297,45 +2385,57 @@ func main() {
wg1.Wait() wg1.Wait()
} }
// sessionPool manages a pool of smux sessions for round-robin TCP distribution. // sessionPool manages smux sessions for round-robin TCP distribution.
// Lock-free reads via atomic.Pointer copy-on-write snapshot.
type sessionPool struct { type sessionPool struct {
mu sync.RWMutex sessions atomic.Pointer[[]*smux.Session]
sessions []*smux.Session mu sync.Mutex
counter atomic.Uint64 counter atomic.Uint64
} }
func (p *sessionPool) snapshot() []*smux.Session {
cur := p.sessions.Load()
if cur == nil {
return nil
}
return *cur
}
func (p *sessionPool) add(s *smux.Session) { func (p *sessionPool) add(s *smux.Session) {
p.mu.Lock() p.mu.Lock()
p.sessions = append(p.sessions, s) defer p.mu.Unlock()
p.mu.Unlock() cur := p.snapshot()
next := make([]*smux.Session, len(cur)+1)
copy(next, cur)
next[len(cur)] = s
p.sessions.Store(&next)
} }
func (p *sessionPool) remove(s *smux.Session) { func (p *sessionPool) remove(s *smux.Session) {
p.mu.Lock() p.mu.Lock()
for i, sess := range p.sessions { defer p.mu.Unlock()
if sess == s { cur := p.snapshot()
p.sessions = append(p.sessions[:i], p.sessions[i+1:]...) next := make([]*smux.Session, 0, len(cur))
break for _, sess := range cur {
if sess != s {
next = append(next, sess)
} }
} }
p.mu.Unlock() p.sessions.Store(&next)
} }
func (p *sessionPool) pick() *smux.Session { func (p *sessionPool) pick() *smux.Session {
p.mu.RLock() cur := p.snapshot()
defer p.mu.RUnlock() n := len(cur)
n := len(p.sessions)
if n == 0 { if n == 0 {
return nil return nil
} }
idx := p.counter.Add(1) % uint64(n) idx := p.counter.Add(1) % uint64(n)
return p.sessions[idx] return cur[idx]
} }
func (p *sessionPool) count() int { func (p *sessionPool) count() int {
p.mu.RLock() return len(p.snapshot())
defer p.mu.RUnlock()
return len(p.sessions)
} }
// runVLESSMode implements TCP forwarding with round-robin across N TURN sessions. // runVLESSMode implements TCP forwarding with round-robin across N TURN sessions.

Loading…
Cancel
Save