diff --git a/client/main.go b/client/main.go index 58e347b..0ebc343 100644 --- a/client/main.go +++ b/client/main.go @@ -884,7 +884,7 @@ func main() { //nolint:cyclop } if *tcpMode { - runTCPMode(ctx, params, peer, *listen) + runTCPMode(ctx, params, peer, *listen, *n) return } @@ -946,10 +946,126 @@ func main() { //nolint:cyclop wg1.Wait() } -// runTCPMode implements TCP forwarding mode for VLESS. -// It establishes a DTLS tunnel through TURN, then creates a KCP+smux session -// on top, and forwards incoming TCP connections as smux streams. -func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string) { +// sessionPool manages a pool of smux sessions for round-robin TCP distribution. +type sessionPool struct { + mu sync.RWMutex + sessions []*smux.Session + counter atomic.Uint64 +} + +func (p *sessionPool) add(s *smux.Session) { + p.mu.Lock() + p.sessions = append(p.sessions, s) + p.mu.Unlock() +} + +func (p *sessionPool) remove(s *smux.Session) { + p.mu.Lock() + for i, sess := range p.sessions { + if sess == s { + p.sessions = append(p.sessions[:i], p.sessions[i+1:]...) + break + } + } + p.mu.Unlock() +} + +func (p *sessionPool) pick() *smux.Session { + p.mu.RLock() + defer p.mu.RUnlock() + n := len(p.sessions) + if n == 0 { + return nil + } + idx := p.counter.Add(1) % uint64(n) + return p.sessions[idx] +} + +func (p *sessionPool) count() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.sessions) +} + +// runTCPMode implements TCP forwarding with round-robin across N TURN sessions. +func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string, numSessions int) { + pool := &sessionPool{} + + // Start N session maintainers with staggered startup + var wgMaint sync.WaitGroup + for i := 0; i < numSessions; i++ { + wgMaint.Add(1) + go func(id int) { + defer wgMaint.Done() + select { + case <-ctx.Done(): + return + case <-time.After(time.Duration(id) * 300 * time.Millisecond): + } + maintainTCPSession(ctx, tp, peer, id, pool) + }(i) + } + + // Wait for at least one session + log.Printf("TCP mode: waiting for sessions to connect (total: %d)...", numSessions) + for { + select { + case <-ctx.Done(): + wgMaint.Wait() + return + case <-time.After(100 * time.Millisecond): + } + if pool.count() > 0 { + break + } + } + + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + log.Panicf("TCP listen: %s", err) + } + context.AfterFunc(ctx, func() { listener.Close() }) + log.Printf("TCP mode: listening on %s (round-robin across %d sessions)", listenAddr, numSessions) + + var wgConn sync.WaitGroup + for { + tcpConn, err := listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + wgConn.Wait() + wgMaint.Wait() + return + default: + } + log.Printf("TCP accept error: %s", err) + continue + } + + sess := pool.pick() + if sess == nil || sess.IsClosed() { + log.Printf("No active sessions, rejecting connection") + tcpConn.Close() + continue + } + + wgConn.Add(1) + go func(tc net.Conn, s *smux.Session) { + defer wgConn.Done() + defer tc.Close() + stream, err := s.OpenStream() + if err != nil { + log.Printf("smux open stream error: %s", err) + return + } + defer stream.Close() + pipe(ctx, tc, stream) + }(tcpConn, sess) + } +} + +// maintainTCPSession keeps one TURN+DTLS+KCP+smux session alive, reconnecting on failure. +func maintainTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, id int, pool *sessionPool) { for { select { case <-ctx.Done(): @@ -957,11 +1073,34 @@ func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAd default: } - err := runTCPSession(ctx, tp, peer, listenAddr) + smuxSess, cleanup, err := createSmuxSession(ctx, tp, peer) if err != nil { - log.Printf("TCP session error: %s, reconnecting...", err) + log.Printf("[session %d] setup error: %s, retrying...", id, err) + select { + case <-ctx.Done(): + return + case <-time.After(3 * time.Second): + } + continue } + pool.add(smuxSess) + log.Printf("[session %d] connected (active: %d)", id, pool.count()) + + for !smuxSess.IsClosed() { + select { + case <-ctx.Done(): + pool.remove(smuxSess) + cleanup() + return + case <-time.After(1 * time.Second): + } + } + + pool.remove(smuxSess) + cleanup() + log.Printf("[session %d] disconnected (active: %d), reconnecting...", id, pool.count()) + select { case <-ctx.Done(): return @@ -970,15 +1109,24 @@ func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAd } } -func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string) error { +// createSmuxSession establishes a full TURN+DTLS+KCP+smux pipeline and returns +// the smux session along with a cleanup function to tear down all layers. +func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr) (*smux.Session, func(), error) { + var cleanupFns []func() + cleanup := func() { + for i := len(cleanupFns) - 1; i >= 0; i-- { + cleanupFns[i]() + } + } + // 1. Get TURN credentials - user, pass, url, err := tp.getCreds(tp.link) + user, pass, rawURL, err := tp.getCreds(tp.link) if err != nil { - return fmt.Errorf("get TURN creds: %w", err) + return nil, nil, fmt.Errorf("get TURN creds: %w", err) } - urlhost, urlport, err := net.SplitHostPort(url) + urlhost, urlport, err := net.SplitHostPort(rawURL) if err != nil { - return fmt.Errorf("parse TURN addr: %w", err) + return nil, nil, fmt.Errorf("parse TURN addr: %w", err) } if tp.host != "" { urlhost = tp.host @@ -989,10 +1137,9 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste turnServerAddr := net.JoinHostPort(urlhost, urlport) turnServerUdpAddr, err := net.ResolveUDPAddr("udp", turnServerAddr) if err != nil { - return fmt.Errorf("resolve TURN addr: %w", err) + return nil, nil, fmt.Errorf("resolve TURN addr: %w", err) } turnServerAddr = turnServerUdpAddr.String() - fmt.Println(turnServerUdpAddr.IP) // 2. Connect to TURN server var turnConn net.PacketConn @@ -1001,21 +1148,21 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste if tp.udp { conn, err := net.DialUDP("udp", nil, turnServerUdpAddr) if err != nil { - return fmt.Errorf("dial TURN (udp): %w", err) + return nil, nil, fmt.Errorf("dial TURN (udp): %w", err) } - defer conn.Close() + cleanupFns = append(cleanupFns, func() { conn.Close() }) turnConn = &connectedUDPConn{conn} } else { var d net.Dialer conn, err := d.DialContext(ctx1, "tcp", turnServerAddr) if err != nil { - return fmt.Errorf("dial TURN (tcp): %w", err) + return nil, nil, fmt.Errorf("dial TURN (tcp): %w", err) } - defer conn.Close() + cleanupFns = append(cleanupFns, func() { conn.Close() }) turnConn = turn.NewSTUNConn(conn) } - // 3. Allocate TURN relay + // 3. Create TURN client and allocate relay var addrFamily turn.RequestedAddressFamily if peer.IP.To4() != nil { addrFamily = turn.RequestedAddressFamilyIPv4 @@ -1033,28 +1180,29 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste } turnClient, err := turn.NewClient(cfg) if err != nil { - return fmt.Errorf("create TURN client: %w", err) + cleanup() + return nil, nil, fmt.Errorf("create TURN client: %w", err) } - defer turnClient.Close() + cleanupFns = append(cleanupFns, func() { turnClient.Close() }) if err = turnClient.Listen(); err != nil { - return fmt.Errorf("TURN listen: %w", err) + cleanup() + return nil, nil, fmt.Errorf("TURN listen: %w", err) } relayConn, err := turnClient.Allocate() if err != nil { - return fmt.Errorf("TURN allocate: %w", err) + cleanup() + return nil, nil, fmt.Errorf("TURN allocate: %w", err) } - defer relayConn.Close() + cleanupFns = append(cleanupFns, func() { relayConn.Close() }) log.Printf("relayed-address=%s", relayConn.LocalAddr().String()) // 4. Establish DTLS over TURN relay certificate, err := selfsign.GenerateSelfSigned() if err != nil { - return fmt.Errorf("generate cert: %w", err) + cleanup() + return nil, nil, fmt.Errorf("generate cert: %w", err) } - - // Create a connected PacketConn for DTLS: relay writes go to peer dtlsPC := &relayPacketConn{relay: relayConn, peer: peer} - dtlsConfig := &dtls.Config{ Certificates: []tls.Certificate{certificate}, InsecureSkipVerify: true, @@ -1062,78 +1210,40 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ConnectionIDGenerator: dtls.OnlySendCIDGenerator(), } - dtlsConn, err := dtls.Client(dtlsPC, peer, dtlsConfig) if err != nil { - return fmt.Errorf("DTLS client create: %w", err) + cleanup() + return nil, nil, fmt.Errorf("DTLS client create: %w", err) } - ctx2, cancel2 := context.WithTimeout(ctx, 30*time.Second) defer cancel2() if err = dtlsConn.HandshakeContext(ctx2); err != nil { dtlsConn.Close() - return fmt.Errorf("DTLS handshake: %w", err) + cleanup() + return nil, nil, fmt.Errorf("DTLS handshake: %w", err) } - defer dtlsConn.Close() + cleanupFns = append(cleanupFns, func() { dtlsConn.Close() }) log.Printf("DTLS connection established") // 5. Create KCP session over DTLS kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, false) if err != nil { - return fmt.Errorf("KCP session: %w", err) + cleanup() + return nil, nil, fmt.Errorf("KCP session: %w", err) } - defer kcpSess.Close() + cleanupFns = append(cleanupFns, func() { kcpSess.Close() }) log.Printf("KCP session established") // 6. Create smux client session over KCP smuxSess, err := smux.Client(kcpSess, tcputil.DefaultSmuxConfig()) if err != nil { - return fmt.Errorf("smux client: %w", err) + cleanup() + return nil, nil, fmt.Errorf("smux client: %w", err) } - defer smuxSess.Close() + cleanupFns = append(cleanupFns, func() { smuxSess.Close() }) log.Printf("smux session established") - // 7. Listen for TCP connections and forward through smux - listener, err := net.Listen("tcp", listenAddr) - if err != nil { - return fmt.Errorf("TCP listen: %w", err) - } - context.AfterFunc(ctx, func() { listener.Close() }) - log.Printf("TCP mode: listening on %s", listenAddr) - - var wg sync.WaitGroup - for { - tcpConn, err := listener.Accept() - if err != nil { - select { - case <-ctx.Done(): - wg.Wait() - return nil - default: - } - if smuxSess.IsClosed() { - wg.Wait() - return fmt.Errorf("smux session closed") - } - log.Printf("TCP accept error: %s", err) - continue - } - - wg.Add(1) - go func(tc net.Conn) { - defer wg.Done() - defer tc.Close() - - stream, err := smuxSess.OpenStream() - if err != nil { - log.Printf("smux open stream error: %s", err) - return - } - defer stream.Close() - - pipe(ctx, tc, stream) - }(tcpConn) - } + return smuxSess, cleanup, nil } // relayPacketConn wraps a TURN relay PacketConn to direct all writes to the peer.