diff --git a/client/main.go b/client/main.go
index 09628ec..5ea94b7 100644
--- a/client/main.go
+++ b/client/main.go
@@ -811,6 +811,7 @@ const (
 	cacheSafetyMargin  = 60 * time.Second
 	maxCacheErrors     = 3
 	errorWindow        = 10 * time.Second
+	turnServerCooldown = 30 * time.Second
 )
 
 var streamsPerCache = 10
@@ -831,6 +832,69 @@ var credentialsStore = struct {
 	caches: make(map[int]*StreamCredentialsCache),
 }
 
+// streamServerOffsets holds, per stream ID, a rotation counter that is added to
+// the stream ID when choosing a TURN server (values are *atomic.Uint64).
+var streamServerOffsets sync.Map // map[int]*atomic.Uint64
+
+// turnServerCooldowns holds, per TURN server address, the UnixNano deadline
+// before which the server should be skipped (values are *atomic.Int64).
+var turnServerCooldowns sync.Map // map[string]*atomic.Int64
+
+// getStreamServerOffset returns the current rotation offset for streamID, or 0
+// if the stream has never been rotated.
+func getStreamServerOffset(streamID int) uint64 {
+	// Plain Load: a read must not allocate a counter or insert a map entry.
+	v, ok := streamServerOffsets.Load(streamID)
+	if !ok {
+		return 0
+	}
+	return v.(*atomic.Uint64).Load()
+}
+
+// rotateStreamServer advances the rotation offset for streamID by one and
+// returns the new value.
+func rotateStreamServer(streamID int) uint64 {
+	v, _ := streamServerOffsets.LoadOrStore(streamID, &atomic.Uint64{})
+	return v.(*atomic.Uint64).Add(1)
+}
+
+// pickStreamServerAddr chooses a TURN server address for streamID from addrs.
+// The scan starts at (streamID + rotation offset) mod len(addrs) and returns
+// the first address that is not cooling down. If every address is cooling
+// down, the starting address is returned anyway. It returns "" when addrs is
+// empty.
+func pickStreamServerAddr(streamID int, addrs []string) string {
+	n := uint64(len(addrs))
+	if n == 0 {
+		// Guard the modulo below: an empty list would panic with a division by zero.
+		return ""
+	}
+	start := (uint64(streamID) + getStreamServerOffset(streamID)) % n
+	for i := uint64(0); i < n; i++ {
+		addr := addrs[(start+i)%n]
+		if isTURNServerAvailable(addr) {
+			return addr
+		}
+	}
+	return addrs[start]
+}
+
+// markTURNServerCooldown makes addr unavailable for turnServerCooldown from now.
+func markTURNServerCooldown(addr string) {
+	v, _ := turnServerCooldowns.LoadOrStore(addr, &atomic.Int64{})
+	v.(*atomic.Int64).Store(time.Now().Add(turnServerCooldown).UnixNano())
+}
+
+// isTURNServerAvailable reports whether addr has no cooldown recorded or its
+// cooldown deadline has passed.
+func isTURNServerAvailable(addr string) bool {
+	v, ok := turnServerCooldowns.Load(addr)
+	if !ok {
+		return true
+	}
+	return time.Now().UnixNano() >= v.(*atomic.Int64).Load()
+}
+
 func getStreamCache(streamID int) *StreamCredentialsCache {
 	cacheID := getCacheID(streamID)
 
@@ -908,8 +972,7 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn
 	if cache.creds.Link == link && time.Now().Before(cache.creds.ExpiresAt) && len(cache.creds.ServerAddrs) > 0 {
 		expires := time.Until(cache.creds.ExpiresAt)
 		u, p := cache.creds.Username, cache.creds.Password
-		// Round-robin selection based on streamID
-		addr := cache.creds.ServerAddrs[streamID%len(cache.creds.ServerAddrs)]
+		addr := pickStreamServerAddr(streamID, cache.creds.ServerAddrs)
 		cache.mutex.RUnlock()
 		if isDebug {
 			log.Printf("[STREAM %d] [VK Auth] Using cached credentials (cache=%d, expires in %v, server=%s)", streamID, cacheID, expires, addr)
@@ -923,7 +986,7 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn
 
 	// Double-check inside lock
 	if cache.creds.Link == link && time.Now().Before(cache.creds.ExpiresAt) && len(cache.creds.ServerAddrs) > 0 {
-		addr := cache.creds.ServerAddrs[streamID%len(cache.creds.ServerAddrs)]
+		addr := pickStreamServerAddr(streamID, cache.creds.ServerAddrs)
 		return cache.creds.Username, cache.creds.Password, addr, nil
 	}
 
@@ -933,7 +996,7 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn
 	}
 
 	cache.creds = TurnCredentials{Username: user, Password: pass, ServerAddrs: addrs, ExpiresAt: time.Now().Add(credentialLifetime - cacheSafetyMargin), Link: link}
-	addr := addrs[streamID%len(addrs)]
+	addr := pickStreamServerAddr(streamID, addrs)
 	return user, pass, addr, nil
 }
 
@@ -2791,6 +2854,15 @@ func maintainVLESSSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr
 
 		smuxSess, cleanup, err := createSmuxSession(ctx, tp, peer, id)
 		if err != nil {
+			if shouldRotateTURNServer(err) {
+				offset := rotateStreamServer(id)
+				if addr, ok := turnSetupAddr(err); ok {
+					markTURNServerCooldown(addr)
+					debugf("[session %d] cooling down TURN server %s for %s after setup failure (offset=%d)", id, addr, turnServerCooldown, offset)
+				} else {
+					debugf("[session %d] rotating TURN server after setup failure (offset=%d)", id, offset)
+				}
+			}
 			log.Printf("[session %d] setup error: %s, retrying...", id, err)
 			select {
 			case <-ctx.Done():
@@ -2825,6 +2897,53 @@ func maintainVLESSSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr
 	}
 }
 
+// turnSetupError is a TURN/DTLS setup failure annotated with the address of
+// the TURN server that was in use, so the caller can put that server on cooldown.
+type turnSetupError struct {
+	addr string // TURN server address the failed attempt was using
+	err  error  // underlying error, already wrapped with the failing step
+}
+
+// Error returns the message of the wrapped error unchanged.
+func (e *turnSetupError) Error() string {
+	return e.err.Error()
+}
+
+// Unwrap exposes the wrapped error to errors.Is and errors.As.
+func (e *turnSetupError) Unwrap() error {
+	return e.err
+}
+
+// turnSetupAddr returns the TURN server address recorded in err's chain.
+// ok is false when err carries no *turnSetupError or its address is empty.
+func turnSetupAddr(err error) (string, bool) {
+	var setupErr *turnSetupError
+	if errors.As(err, &setupErr) && setupErr.addr != "" {
+		return setupErr.addr, true
+	}
+	return "", false
+}
+
+// shouldRotateTURNServer reports whether err is a setup failure that should
+// move the stream to another TURN server.
+func shouldRotateTURNServer(err error) bool {
+	if err == nil {
+		return false
+	}
+	// createSmuxSession wraps its TURN dial, listen, allocate and DTLS handshake
+	// failures in *turnSetupError; matching the type also covers "TURN listen",
+	// which the substring checks below do not.
+	var setupErr *turnSetupError
+	if errors.As(err, &setupErr) {
+		return true
+	}
+	// Fallback for errors that were not wrapped with a server address.
+	errStr := err.Error()
+	return strings.Contains(errStr, "dial TURN") ||
+		strings.Contains(errStr, "TURN allocate") ||
+		strings.Contains(errStr, "DTLS handshake")
+}
+
 // createSmuxSession establishes a full TURN+DTLS+KCP+smux pipeline and returns
 // the smux session along with a cleanup function to tear down all layers.
 func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, id int) (*smux.Session, func(), error) {
@@ -2865,7 +2984,7 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i
 	if tp.udp {
 		c, err1 := net.DialUDP("udp", nil, turnServerUDPAddr)
 		if err1 != nil {
-			return nil, nil, fmt.Errorf("dial TURN (udp): %w", err1)
+			return nil, nil, &turnSetupError{addr: turnServerAddr, err: fmt.Errorf("dial TURN (udp): %w", err1)}
 		}
 		cleanupFns = append(cleanupFns, func() { _ = c.Close() })
 		turnConn = &connectedUDPConn{c}
@@ -2873,7 +2992,7 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i
 		var d net.Dialer
 		c, err1 := d.DialContext(ctx1, "tcp", turnServerAddr)
 		if err1 != nil {
-			return nil, nil, fmt.Errorf("dial TURN (tcp): %w", err1)
+			return nil, nil, &turnSetupError{addr: turnServerAddr, err: fmt.Errorf("dial TURN (tcp): %w", err1)}
 		}
 		cleanupFns = append(cleanupFns, func() { _ = c.Close() })
 		turnConn = turn.NewSTUNConn(c)
@@ -2904,12 +3023,12 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i
 	cleanupFns = append(cleanupFns, func() { turnClient.Close() })
 	if err = turnClient.Listen(); err != nil {
 		cleanup()
-		return nil, nil, fmt.Errorf("TURN listen: %w", err)
+		return nil, nil, &turnSetupError{addr: turnServerAddr, err: fmt.Errorf("TURN listen: %w", err)}
 	}
 	relayConn, err := turnClient.Allocate()
 	if err != nil {
 		cleanup()
-		return nil, nil, fmt.Errorf("TURN allocate: %w", err)
+		return nil, nil, &turnSetupError{addr: turnServerAddr, err: fmt.Errorf("TURN allocate: %w", err)}
 	}
 	cleanupFns = append(cleanupFns, func() { _ = relayConn.Close() })
 	debugf("relayed-address=%s", relayConn.LocalAddr().String())
@@ -2937,7 +3056,7 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i
 	if err = dtlsConn.HandshakeContext(ctx2); err != nil {
 		_ = dtlsConn.Close()
 		cleanup()
-		return nil, nil, fmt.Errorf("DTLS handshake: %w", err)
+		return nil, nil, &turnSetupError{addr: turnServerAddr, err: fmt.Errorf("DTLS handshake: %w", err)}
 	}
 	cleanupFns = append(cleanupFns, func() { _ = dtlsConn.Close() })
 	debugf("DTLS connection established")