Browse Source

feat: round-robin multi-session TCP mode for increased throughput

Distribute incoming TCP connections across N parallel TURN+DTLS+KCP+smux
sessions in round-robin fashion, aggregating bandwidth of multiple relays.
Add sessionPool with thread-safe add/remove/pick operations.
Each session is maintained by its own goroutine with auto-reconnect.
The -n flag now controls session count in TCP mode (default 16 for VK).
Refactor: extract createSmuxSession from runTCPSession for reuse.
pull/74/head
Moroka8 2 months ago
parent
commit
8acc5535f7
  1. 266
      client/main.go

266
client/main.go

@ -884,7 +884,7 @@ func main() { //nolint:cyclop
}
if *tcpMode {
runTCPMode(ctx, params, peer, *listen)
runTCPMode(ctx, params, peer, *listen, *n)
return
}
@ -946,10 +946,126 @@ func main() { //nolint:cyclop
wg1.Wait()
}
// runTCPMode implements TCP forwarding mode for VLESS.
// It establishes a DTLS tunnel through TURN, then creates a KCP+smux session
// on top, and forwards incoming TCP connections as smux streams.
func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string) {
// sessionPool manages a pool of smux sessions for round-robin TCP distribution.
type sessionPool struct {
mu sync.RWMutex
sessions []*smux.Session
counter atomic.Uint64
}
func (p *sessionPool) add(s *smux.Session) {
p.mu.Lock()
p.sessions = append(p.sessions, s)
p.mu.Unlock()
}
func (p *sessionPool) remove(s *smux.Session) {
p.mu.Lock()
for i, sess := range p.sessions {
if sess == s {
p.sessions = append(p.sessions[:i], p.sessions[i+1:]...)
break
}
}
p.mu.Unlock()
}
func (p *sessionPool) pick() *smux.Session {
p.mu.RLock()
defer p.mu.RUnlock()
n := len(p.sessions)
if n == 0 {
return nil
}
idx := p.counter.Add(1) % uint64(n)
return p.sessions[idx]
}
func (p *sessionPool) count() int {
p.mu.RLock()
defer p.mu.RUnlock()
return len(p.sessions)
}
// runTCPMode implements TCP forwarding with round-robin across N TURN sessions.
func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string, numSessions int) {
pool := &sessionPool{}
// Start N session maintainers with staggered startup
var wgMaint sync.WaitGroup
for i := 0; i < numSessions; i++ {
wgMaint.Add(1)
go func(id int) {
defer wgMaint.Done()
select {
case <-ctx.Done():
return
case <-time.After(time.Duration(id) * 300 * time.Millisecond):
}
maintainTCPSession(ctx, tp, peer, id, pool)
}(i)
}
// Wait for at least one session
log.Printf("TCP mode: waiting for sessions to connect (total: %d)...", numSessions)
for {
select {
case <-ctx.Done():
wgMaint.Wait()
return
case <-time.After(100 * time.Millisecond):
}
if pool.count() > 0 {
break
}
}
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
log.Panicf("TCP listen: %s", err)
}
context.AfterFunc(ctx, func() { listener.Close() })
log.Printf("TCP mode: listening on %s (round-robin across %d sessions)", listenAddr, numSessions)
var wgConn sync.WaitGroup
for {
tcpConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
wgConn.Wait()
wgMaint.Wait()
return
default:
}
log.Printf("TCP accept error: %s", err)
continue
}
sess := pool.pick()
if sess == nil || sess.IsClosed() {
log.Printf("No active sessions, rejecting connection")
tcpConn.Close()
continue
}
wgConn.Add(1)
go func(tc net.Conn, s *smux.Session) {
defer wgConn.Done()
defer tc.Close()
stream, err := s.OpenStream()
if err != nil {
log.Printf("smux open stream error: %s", err)
return
}
defer stream.Close()
pipe(ctx, tc, stream)
}(tcpConn, sess)
}
}
// maintainTCPSession keeps one TURN+DTLS+KCP+smux session alive, reconnecting on failure.
func maintainTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, id int, pool *sessionPool) {
for {
select {
case <-ctx.Done():
@ -957,11 +1073,34 @@ func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAd
default:
}
err := runTCPSession(ctx, tp, peer, listenAddr)
smuxSess, cleanup, err := createSmuxSession(ctx, tp, peer)
if err != nil {
log.Printf("TCP session error: %s, reconnecting...", err)
log.Printf("[session %d] setup error: %s, retrying...", id, err)
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
continue
}
pool.add(smuxSess)
log.Printf("[session %d] connected (active: %d)", id, pool.count())
for !smuxSess.IsClosed() {
select {
case <-ctx.Done():
pool.remove(smuxSess)
cleanup()
return
case <-time.After(1 * time.Second):
}
}
pool.remove(smuxSess)
cleanup()
log.Printf("[session %d] disconnected (active: %d), reconnecting...", id, pool.count())
select {
case <-ctx.Done():
return
@ -970,15 +1109,24 @@ func runTCPMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAd
}
}
func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string) error {
// createSmuxSession establishes a full TURN+DTLS+KCP+smux pipeline and returns
// the smux session along with a cleanup function to tear down all layers.
func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr) (*smux.Session, func(), error) {
var cleanupFns []func()
cleanup := func() {
for i := len(cleanupFns) - 1; i >= 0; i-- {
cleanupFns[i]()
}
}
// 1. Get TURN credentials
user, pass, url, err := tp.getCreds(tp.link)
user, pass, rawURL, err := tp.getCreds(tp.link)
if err != nil {
return fmt.Errorf("get TURN creds: %w", err)
return nil, nil, fmt.Errorf("get TURN creds: %w", err)
}
urlhost, urlport, err := net.SplitHostPort(url)
urlhost, urlport, err := net.SplitHostPort(rawURL)
if err != nil {
return fmt.Errorf("parse TURN addr: %w", err)
return nil, nil, fmt.Errorf("parse TURN addr: %w", err)
}
if tp.host != "" {
urlhost = tp.host
@ -989,10 +1137,9 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste
turnServerAddr := net.JoinHostPort(urlhost, urlport)
turnServerUdpAddr, err := net.ResolveUDPAddr("udp", turnServerAddr)
if err != nil {
return fmt.Errorf("resolve TURN addr: %w", err)
return nil, nil, fmt.Errorf("resolve TURN addr: %w", err)
}
turnServerAddr = turnServerUdpAddr.String()
fmt.Println(turnServerUdpAddr.IP)
// 2. Connect to TURN server
var turnConn net.PacketConn
@ -1001,21 +1148,21 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste
if tp.udp {
conn, err := net.DialUDP("udp", nil, turnServerUdpAddr)
if err != nil {
return fmt.Errorf("dial TURN (udp): %w", err)
return nil, nil, fmt.Errorf("dial TURN (udp): %w", err)
}
defer conn.Close()
cleanupFns = append(cleanupFns, func() { conn.Close() })
turnConn = &connectedUDPConn{conn}
} else {
var d net.Dialer
conn, err := d.DialContext(ctx1, "tcp", turnServerAddr)
if err != nil {
return fmt.Errorf("dial TURN (tcp): %w", err)
return nil, nil, fmt.Errorf("dial TURN (tcp): %w", err)
}
defer conn.Close()
cleanupFns = append(cleanupFns, func() { conn.Close() })
turnConn = turn.NewSTUNConn(conn)
}
// 3. Allocate TURN relay
// 3. Create TURN client and allocate relay
var addrFamily turn.RequestedAddressFamily
if peer.IP.To4() != nil {
addrFamily = turn.RequestedAddressFamilyIPv4
@ -1033,28 +1180,29 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste
}
turnClient, err := turn.NewClient(cfg)
if err != nil {
return fmt.Errorf("create TURN client: %w", err)
cleanup()
return nil, nil, fmt.Errorf("create TURN client: %w", err)
}
defer turnClient.Close()
cleanupFns = append(cleanupFns, func() { turnClient.Close() })
if err = turnClient.Listen(); err != nil {
return fmt.Errorf("TURN listen: %w", err)
cleanup()
return nil, nil, fmt.Errorf("TURN listen: %w", err)
}
relayConn, err := turnClient.Allocate()
if err != nil {
return fmt.Errorf("TURN allocate: %w", err)
cleanup()
return nil, nil, fmt.Errorf("TURN allocate: %w", err)
}
defer relayConn.Close()
cleanupFns = append(cleanupFns, func() { relayConn.Close() })
log.Printf("relayed-address=%s", relayConn.LocalAddr().String())
// 4. Establish DTLS over TURN relay
certificate, err := selfsign.GenerateSelfSigned()
if err != nil {
return fmt.Errorf("generate cert: %w", err)
cleanup()
return nil, nil, fmt.Errorf("generate cert: %w", err)
}
// Create a connected PacketConn for DTLS: relay writes go to peer
dtlsPC := &relayPacketConn{relay: relayConn, peer: peer}
dtlsConfig := &dtls.Config{
Certificates: []tls.Certificate{certificate},
InsecureSkipVerify: true,
@ -1062,78 +1210,40 @@ func runTCPSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, liste
CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ConnectionIDGenerator: dtls.OnlySendCIDGenerator(),
}
dtlsConn, err := dtls.Client(dtlsPC, peer, dtlsConfig)
if err != nil {
return fmt.Errorf("DTLS client create: %w", err)
cleanup()
return nil, nil, fmt.Errorf("DTLS client create: %w", err)
}
ctx2, cancel2 := context.WithTimeout(ctx, 30*time.Second)
defer cancel2()
if err = dtlsConn.HandshakeContext(ctx2); err != nil {
dtlsConn.Close()
return fmt.Errorf("DTLS handshake: %w", err)
cleanup()
return nil, nil, fmt.Errorf("DTLS handshake: %w", err)
}
defer dtlsConn.Close()
cleanupFns = append(cleanupFns, func() { dtlsConn.Close() })
log.Printf("DTLS connection established")
// 5. Create KCP session over DTLS
kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, false)
if err != nil {
return fmt.Errorf("KCP session: %w", err)
cleanup()
return nil, nil, fmt.Errorf("KCP session: %w", err)
}
defer kcpSess.Close()
cleanupFns = append(cleanupFns, func() { kcpSess.Close() })
log.Printf("KCP session established")
// 6. Create smux client session over KCP
smuxSess, err := smux.Client(kcpSess, tcputil.DefaultSmuxConfig())
if err != nil {
return fmt.Errorf("smux client: %w", err)
cleanup()
return nil, nil, fmt.Errorf("smux client: %w", err)
}
defer smuxSess.Close()
cleanupFns = append(cleanupFns, func() { smuxSess.Close() })
log.Printf("smux session established")
// 7. Listen for TCP connections and forward through smux
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return fmt.Errorf("TCP listen: %w", err)
}
context.AfterFunc(ctx, func() { listener.Close() })
log.Printf("TCP mode: listening on %s", listenAddr)
var wg sync.WaitGroup
for {
tcpConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
wg.Wait()
return nil
default:
}
if smuxSess.IsClosed() {
wg.Wait()
return fmt.Errorf("smux session closed")
}
log.Printf("TCP accept error: %s", err)
continue
}
wg.Add(1)
go func(tc net.Conn) {
defer wg.Done()
defer tc.Close()
stream, err := smuxSess.OpenStream()
if err != nil {
log.Printf("smux open stream error: %s", err)
return
}
defer stream.Close()
pipe(ctx, tc, stream)
}(tcpConn)
}
return smuxSess, cleanup, nil
}
// relayPacketConn wraps a TURN relay PacketConn to direct all writes to the peer.

Loading…
Cancel
Save