From 96831328bff085137d32c07f8dd8d2dae61d8eae Mon Sep 17 00:00:00 2001 From: Moroka8 Date: Sat, 2 May 2026 13:46:30 +0700 Subject: [PATCH] feat: add connection bonding and configurable kcp profiles Implement multi-lane transport (bonding) for VLESS to aggregate bandwidth from multiple KCP/DTLS sessions. Added environment-based KCP tuning and throughput statistics. - Support multi-lane bonding in client and server - Add KCP profiles (fast, balanced, slow) via VK_TURN_KCP_PROFILE - Auto-detect bonding streams via magic prefix - Add real-time throughput logging for active connections --- Dockerfile | 1 + client/main.go | 543 +++++++++++++++++++++++++++++++++++++-- docker-entrypoint.sh | 7 +- server/main.go | 600 ++++++++++++++++++++++++++++++++++++++++++- tcputil/tcputil.go | 111 +++++++- 5 files changed, 1221 insertions(+), 41 deletions(-) diff --git a/Dockerfile b/Dockerfile index f3c8350..eaaec0c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,7 @@ COPY docker-entrypoint.sh . COPY --from=builder /build/vk-turn-proxy . RUN chmod +x docker-entrypoint.sh +EXPOSE 56000/tcp EXPOSE 56000/udp ENTRYPOINT ["./docker-entrypoint.sh"] diff --git a/client/main.go b/client/main.go index b6528be..7f053d6 100644 --- a/client/main.go +++ b/client/main.go @@ -9,6 +9,7 @@ import ( "crypto/md5" "crypto/sha256" "encoding/base64" + "encoding/binary" "encoding/hex" "encoding/json" "flag" @@ -121,6 +122,102 @@ var packetPool = sync.Pool{ New: func() any { return &UDPPacket{Data: make([]byte, 2048)} }, } +type throughputStats struct { + tx atomic.Uint64 + rx atomic.Uint64 +} + +func (s *throughputStats) addTx(n int) { + if n > 0 { + s.tx.Add(uint64(n)) + } +} + +func (s *throughputStats) addRx(n int) { + if n > 0 { + s.rx.Add(uint64(n)) + } +} + +func (s *throughputStats) logEvery(ctx context.Context, label, txName, rxName string) { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + var prevTx, prevRx uint64 + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + tx := s.tx.Load() + rx := s.rx.Load() + deltaTx := tx - prevTx + deltaRx := rx - prevRx + prevTx = tx + prevRx = rx + + if deltaTx == 0 && deltaRx == 0 { + continue + } + + log.Printf( + "%s throughput: %s=%s %s=%s total_%s=%s total_%s=%s", + label, + txName, + formatBitsPerSecond(deltaTx, 5*time.Second), + rxName, + formatBitsPerSecond(deltaRx, 5*time.Second), + txName, + formatByteCount(tx), + rxName, + formatByteCount(rx), + ) + } + } +} + +func formatBitsPerSecond(bytes uint64, interval time.Duration) string { + if interval <= 0 { + interval = time.Second + } + + bps := float64(bytes*8) / interval.Seconds() + if bps >= 1_000_000 { + return fmt.Sprintf("%.2f Mbit/s", bps/1_000_000) + } + if bps >= 1_000 { + return fmt.Sprintf("%.1f kbit/s", bps/1_000) + } + return fmt.Sprintf("%.0f bit/s", bps) +} + +func formatByteCount(bytes uint64) string { + if bytes >= 1024*1024 { + return fmt.Sprintf("%.2f MiB", float64(bytes)/(1024*1024)) + } + if bytes >= 1024 { + return fmt.Sprintf("%.1f KiB", float64(bytes)/1024) + } + return fmt.Sprintf("%d B", bytes) +} + +type countingConn struct { + net.Conn + stats *throughputStats +} + +func (c *countingConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + c.stats.addRx(n) + return n, err +} + +func (c *countingConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + c.stats.addTx(n) + return n, err +} + func newDirectNet() transport.Net { return directNet{} } @@ -1754,6 +1851,9 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD wg := sync.WaitGroup{} wg.Add(1) turnctx, turncancel := context.WithCancel(ctx) + stats := &throughputStats{} + go stats.logEvery(turnctx, fmt.Sprintf("[STREAM %d] TURN", streamID), "to-turn", "from-turn") + context.AfterFunc(turnctx, func() { if err := relayConn.SetDeadline(time.Now()); err != nil { log.Printf("Failed to set relay deadline: %s", err) @@ -1779,7 +1879,8 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD internalPipeAddr.Store(addr1) - _, err1 = relayConn.WriteTo(buf[:n], peer) + written, err1 := relayConn.WriteTo(buf[:n], peer) + stats.addTx(written) if err1 != nil { return } @@ -1801,6 +1902,7 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD } if addr, ok := addr1.(net.Addr); ok { + stats.addRx(n) if _, err := conn2.WriteTo(buf[:n], addr); err != nil { return } @@ -1937,6 +2039,7 @@ func main() { udp := flag.Bool("udp", false, "connect to TURN with UDP") direct := flag.Bool("no-dtls", false, "connect without obfuscation. DO NOT USE") vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") + vlessBond := flag.Bool("vless-bond", false, "bond one VLESS TCP connection across all active smux sessions") debugFlag := flag.Bool("debug", false, "enable debug logging") manualCaptchaFlag := flag.Bool("manual-captcha", false, "skip auto captcha solving, use manual mode immediately") flag.Parse() @@ -1996,7 +2099,7 @@ func main() { } if *vlessMode { - runVLESSMode(ctx, params, peer, *listen, *n) + runVLESSMode(ctx, params, peer, *listen, *n, *vlessBond) return } @@ -2097,22 +2200,35 @@ func main() { } // sessionPool manages a pool of smux sessions for round-robin TCP distribution. +type pooledSession struct { + id int + sess *smux.Session + active atomic.Int32 + opened atomic.Uint64 + closed atomic.Uint64 + toSession atomic.Uint64 + fromSession atomic.Uint64 +} + type sessionPool struct { - mu sync.RWMutex - sessions []*smux.Session - counter atomic.Uint64 + mu sync.RWMutex + sessions []*pooledSession + counter atomic.Uint64 + connCounter atomic.Uint64 } -func (p *sessionPool) add(s *smux.Session) { +func (p *sessionPool) add(id int, s *smux.Session) *pooledSession { + ps := &pooledSession{id: id, sess: s} p.mu.Lock() - p.sessions = append(p.sessions, s) + p.sessions = append(p.sessions, ps) p.mu.Unlock() + return ps } -func (p *sessionPool) remove(s *smux.Session) { +func (p *sessionPool) remove(ps *pooledSession) { p.mu.Lock() for i, sess := range p.sessions { - if sess == s { + if sess == ps { p.sessions = append(p.sessions[:i], p.sessions[i+1:]...) break } @@ -2120,25 +2236,346 @@ func (p *sessionPool) remove(s *smux.Session) { p.mu.Unlock() } -func (p *sessionPool) pick() *smux.Session { +func (p *sessionPool) pick() *pooledSession { p.mu.RLock() defer p.mu.RUnlock() n := len(p.sessions) if n == 0 { return nil } - idx := p.counter.Add(1) % uint64(n) + idx := (p.counter.Add(1) - 1) % uint64(n) return p.sessions[idx] } +func (p *sessionPool) nextConnID() uint64 { + return p.connCounter.Add(1) +} + +func (p *sessionPool) snapshot() []*pooledSession { + p.mu.RLock() + defer p.mu.RUnlock() + out := make([]*pooledSession, 0, len(p.sessions)) + for _, ps := range p.sessions { + if !ps.sess.IsClosed() { + out = append(out, ps) + } + } + return out +} + func (p *sessionPool) count() int { p.mu.RLock() defer p.mu.RUnlock() return len(p.sessions) } +const ( + bondVersion = 1 + bondMagic = "VLB1" + + bondFrameData byte = 1 + bondFrameFIN byte = 2 + + bondMaxChunk = 16 * 1024 +) + +type bondFrame struct { + typ byte + seq uint64 + data []byte +} + +type bondClientLane struct { + ps *pooledSession + stream *smux.Stream + mu sync.Mutex + dead atomic.Bool +} + +func writeBondHello(w io.Writer, connID uint64, laneIndex, laneCount uint16) error { + var hdr [17]byte + copy(hdr[0:4], bondMagic) + hdr[4] = bondVersion + binary.BigEndian.PutUint64(hdr[5:13], connID) + binary.BigEndian.PutUint16(hdr[13:15], laneIndex) + binary.BigEndian.PutUint16(hdr[15:17], laneCount) + _, err := w.Write(hdr[:]) + return err +} + +func writeBondFrame(w io.Writer, typ byte, seq uint64, data []byte) error { + var hdr [13]byte + hdr[0] = typ + binary.BigEndian.PutUint64(hdr[1:9], seq) + binary.BigEndian.PutUint32(hdr[9:13], uint32(len(data))) + if _, err := w.Write(hdr[:]); err != nil { + return err + } + if len(data) == 0 { + return nil + } + _, err := w.Write(data) + return err +} + +func readBondFrame(r io.Reader) (bondFrame, error) { + var hdr [13]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return bondFrame{}, err + } + size := binary.BigEndian.Uint32(hdr[9:13]) + if size > 4*1024*1024 { + return bondFrame{}, fmt.Errorf("bond frame too large: %d", size) + } + f := bondFrame{ + typ: hdr[0], + seq: binary.BigEndian.Uint64(hdr[1:9]), + } + if size > 0 { + f.data = make([]byte, size) + if _, err := io.ReadFull(r, f.data); err != nil { + return bondFrame{}, err + } + } + return f, nil +} + +func closeWrite(conn net.Conn) { + type closeWriter interface { + CloseWrite() error + } + if cw, ok := conn.(closeWriter); ok { + if err := cw.CloseWrite(); err != nil && isDebug { + log.Printf("CloseWrite failed: %v", err) + } + } +} + +func handleBondedTCP(ctx context.Context, tcpConn net.Conn, connID uint64, candidates []*pooledSession) { + defer func() { _ = tcpConn.Close() }() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + lanes := make([]*bondClientLane, 0, len(candidates)) + laneIDs := make([]string, 0, len(candidates)) + for i, ps := range candidates { + if ps.sess.IsClosed() { + continue + } + stream, err := ps.sess.OpenStream() + if err != nil { + log.Printf("[bond %d] session %d open stream error: %s", connID, ps.id, err) + continue + } + if err := writeBondHello(stream, connID, uint16(i), uint16(len(candidates))); err != nil { + log.Printf("[bond %d] session %d hello error: %s", connID, ps.id, err) + _ = stream.Close() + continue + } + ps.opened.Add(1) + ps.active.Add(1) + lanes = append(lanes, &bondClientLane{ps: ps, stream: stream}) + laneIDs = append(laneIDs, strconv.Itoa(ps.id)) + } + + if len(lanes) == 0 { + log.Printf("[bond %d] no usable lanes, rejecting TCP from %s", connID, tcpConn.RemoteAddr()) + return + } + context.AfterFunc(ctx, func() { + now := time.Now() + if err := tcpConn.SetDeadline(now); err != nil && isDebug { + log.Printf("[bond %d] local TCP deadline error: %v", connID, err) + } + for _, lane := range lanes { + if err := lane.stream.SetDeadline(now); err != nil && isDebug { + log.Printf("[bond %d] session %d stream deadline error: %v", connID, lane.ps.id, err) + } + } + }) + + log.Printf("[bond %d] TCP accept from=%s lanes=%d [%s]", connID, tcpConn.RemoteAddr(), len(lanes), strings.Join(laneIDs, ",")) + defer func() { + for _, lane := range lanes { + _ = lane.stream.Close() + active := lane.ps.active.Add(-1) + closed := lane.ps.closed.Add(1) + log.Printf("[bond %d] lane session %d close active=%d closed=%d totals: to-session=%s from-session=%s", + connID, lane.ps.id, active, closed, + formatByteCount(lane.ps.toSession.Load()), formatByteCount(lane.ps.fromSession.Load())) + } + }() + + recvCh := make(chan bondFrame, 1024) + var readWG sync.WaitGroup + for _, lane := range lanes { + readWG.Add(1) + go func(l *bondClientLane) { + defer readWG.Done() + for { + f, err := readBondFrame(l.stream) + if err != nil { + l.dead.Store(true) + select { + case <-ctx.Done(): + default: + if err != io.EOF { + log.Printf("[bond %d] session %d read frame error: %v", connID, l.ps.id, err) + } + } + return + } + if f.typ == bondFrameData { + l.ps.fromSession.Add(uint64(len(f.data))) + } + select { + case recvCh <- f: + case <-ctx.Done(): + return + } + } + }(lane) + } + go func() { + readWG.Wait() + close(recvCh) + }() + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + copyTCPToBond(ctx, connID, tcpConn, lanes) + }() + go func() { + defer wg.Done() + copyBondToTCP(ctx, connID, tcpConn, recvCh) + cancel() + }() + wg.Wait() +} + +func copyTCPToBond(ctx context.Context, connID uint64, tcpConn net.Conn, lanes []*bondClientLane) { + buf := make([]byte, bondMaxChunk) + var seq uint64 + var laneIdx uint64 + for { + n, err := tcpConn.Read(buf) + if n > 0 { + data := make([]byte, n) + copy(data, buf[:n]) + lane, writeErr := writeBondFrameToNextLane(ctx, lanes, bondFrameData, seq, data, &laneIdx) + if writeErr != nil { + log.Printf("[bond %d] write data error: %v", connID, writeErr) + return + } + lane.ps.toSession.Add(uint64(n)) + seq++ + } + if err != nil { + if isDebug && err != io.EOF { + log.Printf("[bond %d] local TCP read finished with error: %v", connID, err) + } + for _, lane := range lanes { + if lane.dead.Load() { + continue + } + lane.mu.Lock() + writeErr := writeBondFrame(lane.stream, bondFrameFIN, seq, nil) + lane.mu.Unlock() + if writeErr != nil && ctx.Err() == nil { + log.Printf("[bond %d] session %d write FIN error: %v", connID, lane.ps.id, writeErr) + } + } + log.Printf("[bond %d] upload finished chunks=%d", connID, seq) + return + } + select { + case <-ctx.Done(): + return + default: + } + } +} + +func writeBondFrameToNextLane(ctx context.Context, lanes []*bondClientLane, typ byte, seq uint64, data []byte, laneIdx *uint64) (*bondClientLane, error) { + for attempts := 0; attempts < len(lanes); attempts++ { + idx := *laneIdx % uint64(len(lanes)) + *laneIdx++ + lane := lanes[idx] + if lane.dead.Load() { + continue + } + lane.mu.Lock() + err := writeBondFrame(lane.stream, typ, seq, data) + lane.mu.Unlock() + if err == nil { + return lane, nil + } + lane.dead.Store(true) + if ctx.Err() != nil { + return nil, ctx.Err() + } + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, fmt.Errorf("no live bond lanes") +} + +func copyBondToTCP(ctx context.Context, connID uint64, tcpConn net.Conn, recvCh <-chan bondFrame) { + pending := make(map[uint64][]byte) + var expect uint64 + var finSeq *uint64 + + for { + if finSeq != nil && expect == *finSeq { + closeWrite(tcpConn) + log.Printf("[bond %d] download finished chunks=%d", connID, expect) + return + } + + select { + case <-ctx.Done(): + return + case f, ok := <-recvCh: + if !ok { + return + } + switch f.typ { + case bondFrameData: + pending[f.seq] = f.data + case bondFrameFIN: + v := f.seq + if finSeq == nil || v < *finSeq { + finSeq = &v + } + default: + log.Printf("[bond %d] unknown frame type %d", connID, f.typ) + return + } + + for { + data, ok := pending[expect] + if !ok { + break + } + delete(pending, expect) + if len(data) > 0 { + if _, err := tcpConn.Write(data); err != nil { + log.Printf("[bond %d] local TCP write error: %v", connID, err) + return + } + } + expect++ + } + } + } +} + // runVLESSMode implements TCP forwarding with round-robin across N TURN sessions. -func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string, numSessions int) { +func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string, numSessions int, bond bool) { pool := &sessionPool{} // Start N session maintainers with staggered startup @@ -2182,7 +2619,11 @@ func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listen } context.AfterFunc(ctx, func() { _ = wrappedListener.Close() }) - log.Printf("VLESS mode: listening on %s (round-robin across %d sessions)", listenAddr, numSessions) + if bond { + log.Printf("VLESS bond mode: listening on %s (striping each TCP connection across active sessions)", listenAddr) + } else { + log.Printf("VLESS mode: listening on %s (round-robin across %d sessions)", listenAddr, numSessions) + } var wgConn sync.WaitGroup for { @@ -2199,25 +2640,60 @@ func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listen continue } - sess := pool.pick() - if sess == nil || sess.IsClosed() { + if bond { + connID := (uint64(time.Now().UnixNano()) << 16) ^ pool.nextConnID() + lanes := pool.snapshot() + if len(lanes) == 0 { + log.Printf("No active sessions, rejecting connection") + _ = tcpConn.Close() + continue + } + + wgConn.Add(1) + go func(tc net.Conn, connID uint64, lanes []*pooledSession) { + defer wgConn.Done() + handleBondedTCP(ctx, tc, connID, lanes) + }(tcpConn, connID, lanes) + continue + } + + ps := pool.pick() + if ps == nil || ps.sess.IsClosed() { log.Printf("No active sessions, rejecting connection") _ = tcpConn.Close() continue } + connID := pool.nextConnID() + opened := ps.opened.Add(1) + active := ps.active.Add(1) + log.Printf("[session %d] TCP accept #%d from=%s active=%d opened=%d pool=%d", + ps.id, connID, tcpConn.RemoteAddr(), active, opened, pool.count()) + wgConn.Add(1) - go func(tc net.Conn, s *smux.Session) { + go func(tc net.Conn, ps *pooledSession, connID uint64) { defer wgConn.Done() defer func() { _ = tc.Close() }() - stream, err := s.OpenStream() + defer func() { + active := ps.active.Add(-1) + closed := ps.closed.Add(1) + log.Printf("[session %d] TCP close #%d active=%d closed=%d totals: to-session=%s from-session=%s", + ps.id, connID, active, closed, + formatByteCount(ps.toSession.Load()), formatByteCount(ps.fromSession.Load())) + }() + + stream, err := ps.sess.OpenStream() if err != nil { - log.Printf("smux open stream error: %s", err) + log.Printf("[session %d] smux open stream error for TCP #%d: %s", ps.id, connID, err) return } defer func() { _ = stream.Close() }() - pipe(ctx, tc, stream) - }(tcpConn, sess) + fromSession, toSession := pipe(ctx, tc, stream) + ps.fromSession.Add(uint64(fromSession)) + ps.toSession.Add(uint64(toSession)) + log.Printf("[session %d] TCP done #%d local<-session=%s local->session=%s", + ps.id, connID, formatByteCount(uint64(fromSession)), formatByteCount(uint64(toSession))) + }(tcpConn, ps, connID) } } @@ -2241,20 +2717,20 @@ func maintainVLESSSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr continue } - pool.add(smuxSess) + ps := pool.add(id, smuxSess) log.Printf("[session %d] connected (active: %d)", id, pool.count()) for !smuxSess.IsClosed() { select { case <-ctx.Done(): - pool.remove(smuxSess) + pool.remove(ps) cleanup() return case <-time.After(1 * time.Second): } } - pool.remove(smuxSess) + pool.remove(ps) cleanup() log.Printf("[session %d] disconnected (active: %d), reconnecting...", id, pool.count()) @@ -2384,7 +2860,12 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i log.Printf("DTLS connection established") // 5. Create KCP session over DTLS - kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, false) + statsCtx, statsCancel := context.WithCancel(ctx) + cleanupFns = append(cleanupFns, statsCancel) + stats := &throughputStats{} + go stats.logEvery(statsCtx, fmt.Sprintf("[session %d] VLESS", id), "to-turn", "from-turn") + + kcpSess, err := tcputil.NewKCPOverDTLS(&countingConn{Conn: dtlsConn, stats: stats}, false) if err != nil { cleanup() return nil, nil, fmt.Errorf("KCP session: %w", err) @@ -2425,7 +2906,8 @@ func (r *relayPacketConn) SetReadDeadline(t time.Time) error { return r.relay.S func (r *relayPacketConn) SetWriteDeadline(t time.Time) error { return r.relay.SetWriteDeadline(t) } // pipe copies data bidirectionally between two connections. -func pipe(ctx context.Context, c1, c2 net.Conn) { +// It returns bytes copied as c1<-c2 and c2<-c1. +func pipe(ctx context.Context, c1, c2 net.Conn) (int64, int64) { ctx2, cancel := context.WithCancel(ctx) context.AfterFunc(ctx2, func() { if err := c1.SetDeadline(time.Now()); err != nil { @@ -2437,11 +2919,15 @@ func pipe(ctx context.Context, c1, c2 net.Conn) { }) var wg sync.WaitGroup + var c1FromC2 int64 + var c2FromC1 int64 wg.Add(2) go func() { defer wg.Done() defer cancel() - if _, err := io.Copy(c1, c2); err != nil { + n, err := io.Copy(c1, c2) + c1FromC2 = n + if err != nil { if isDebug { log.Printf("pipe: c1<-c2 copy error: %v", err) } @@ -2450,7 +2936,9 @@ func pipe(ctx context.Context, c1, c2 net.Conn) { go func() { defer wg.Done() defer cancel() - if _, err := io.Copy(c2, c1); err != nil { + n, err := io.Copy(c2, c1) + c2FromC1 = n + if err != nil { if isDebug { log.Printf("pipe: c2<-c1 copy error: %v", err) } @@ -2467,4 +2955,5 @@ func pipe(ctx context.Context, c1, c2 net.Conn) { log.Printf("pipe: failed to reset deadline c2: %v", err) } } + return c1FromC2, c2FromC1 } diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh index 941bf24..86c3879 100644 --- a/docker-entrypoint.sh +++ b/docker-entrypoint.sh @@ -8,4 +8,9 @@ if [ "${VLESS_MODE}" = "true" ]; then VLESS_FLAG="-vless" fi -exec ./vk-turn-proxy -listen 0.0.0.0:56000 -connect "$CONNECT" $VLESS_FLAG +BOND_FLAG="" +if [ "${VLESS_BOND}" = "true" ]; then + BOND_FLAG="-vless-bond" +fi + +exec ./vk-turn-proxy -listen 0.0.0.0:56000 -connect "$CONNECT" $VLESS_FLAG $BOND_FLAG diff --git a/server/main.go b/server/main.go index 00f871a..e7ecc80 100644 --- a/server/main.go +++ b/server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "encoding/binary" "flag" "fmt" "io" @@ -10,6 +11,7 @@ import ( "os" "os/signal" "sync" + "sync/atomic" "syscall" "time" @@ -23,6 +25,7 @@ func main() { listen := flag.String("listen", "0.0.0.0:56000", "listen on ip:port") connect := flag.String("connect", "", "connect to ip:port") vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") + vlessBond := flag.Bool("vless-bond", false, "bond one VLESS TCP connection across all active smux sessions") flag.Parse() ctx, cancel := context.WithCancel(context.Background()) @@ -44,6 +47,7 @@ func main() { if len(*connect) == 0 { log.Panicf("server address is required") } + log.Printf("Starting server listen=%s connect=%s vless=%t vless-bond=%t bond-autodetect=true", *listen, *connect, *vlessMode, *vlessBond) // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() if genErr != nil { @@ -115,7 +119,7 @@ func main() { log.Println("Handshake done") if *vlessMode { - handleVLESSConnection(ctx, dtlsConn, *connect) + handleVLESSConnection(ctx, dtlsConn, *connect, *vlessBond) } else { handleUDPConnection(ctx, conn, *connect) } @@ -125,6 +129,553 @@ func main() { } } +type throughputStats struct { + tx atomic.Uint64 + rx atomic.Uint64 +} + +func (s *throughputStats) addTx(n int) { + if n > 0 { + s.tx.Add(uint64(n)) + } +} + +func (s *throughputStats) addRx(n int) { + if n > 0 { + s.rx.Add(uint64(n)) + } +} + +func (s *throughputStats) logEvery(ctx context.Context, label, txName, rxName string) { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + var prevTx, prevRx uint64 + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + tx := s.tx.Load() + rx := s.rx.Load() + deltaTx := tx - prevTx + deltaRx := rx - prevRx + prevTx = tx + prevRx = rx + + if deltaTx == 0 && deltaRx == 0 { + continue + } + + log.Printf( + "%s throughput: %s=%s %s=%s total_%s=%s total_%s=%s", + label, + txName, + formatBitsPerSecond(deltaTx, 5*time.Second), + rxName, + formatBitsPerSecond(deltaRx, 5*time.Second), + txName, + formatByteCount(tx), + rxName, + formatByteCount(rx), + ) + } + } +} + +func formatBitsPerSecond(bytes uint64, interval time.Duration) string { + if interval <= 0 { + interval = time.Second + } + + bps := float64(bytes*8) / interval.Seconds() + if bps >= 1_000_000 { + return fmt.Sprintf("%.2f Mbit/s", bps/1_000_000) + } + if bps >= 1_000 { + return fmt.Sprintf("%.1f kbit/s", bps/1_000) + } + return fmt.Sprintf("%.0f bit/s", bps) +} + +func formatByteCount(bytes uint64) string { + if bytes >= 1024*1024 { + return fmt.Sprintf("%.2f MiB", float64(bytes)/(1024*1024)) + } + if bytes >= 1024 { + return fmt.Sprintf("%.1f KiB", float64(bytes)/1024) + } + return fmt.Sprintf("%d B", bytes) +} + +type countingConn struct { + net.Conn + stats *throughputStats +} + +func (c *countingConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + c.stats.addRx(n) + return n, err +} + +func (c *countingConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + c.stats.addTx(n) + return n, err +} + +const ( + bondVersion = 1 + bondMagic = "VLB1" + + bondFrameData byte = 1 + bondFrameFIN byte = 2 + + bondMaxChunk = 16 * 1024 + + bondLaneAttachTimeout = 300 * time.Millisecond +) + +type bondHello struct { + connID uint64 + laneIndex uint16 + laneCount uint16 +} + +type bondFrame struct { + typ byte + seq uint64 + data []byte +} + +func readBondHello(r io.Reader) (bondHello, error) { + var hdr [17]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return bondHello{}, err + } + return parseBondHelloHeader(hdr[:]) +} + +func readBondHelloAfterMagic(r io.Reader, magic [4]byte) (bondHello, error) { + var hdr [17]byte + copy(hdr[0:4], magic[:]) + if _, err := io.ReadFull(r, hdr[4:]); err != nil { + return bondHello{}, err + } + return parseBondHelloHeader(hdr[:]) +} + +func parseBondHelloHeader(hdr []byte) (bondHello, error) { + if len(hdr) != 17 { + return bondHello{}, fmt.Errorf("bad bond hello size: %d", len(hdr)) + } + if string(hdr[0:4]) != bondMagic { + return bondHello{}, fmt.Errorf("bad bond magic") + } + if hdr[4] != bondVersion { + return bondHello{}, fmt.Errorf("unsupported bond version: %d", hdr[4]) + } + return bondHello{ + connID: binary.BigEndian.Uint64(hdr[5:13]), + laneIndex: binary.BigEndian.Uint16(hdr[13:15]), + laneCount: binary.BigEndian.Uint16(hdr[15:17]), + }, nil +} + +func writeBondFrame(w io.Writer, typ byte, seq uint64, data []byte) error { + var hdr [13]byte + hdr[0] = typ + binary.BigEndian.PutUint64(hdr[1:9], seq) + binary.BigEndian.PutUint32(hdr[9:13], uint32(len(data))) + if _, err := w.Write(hdr[:]); err != nil { + return err + } + if len(data) == 0 { + return nil + } + _, err := w.Write(data) + return err +} + +func readBondFrame(r io.Reader) (bondFrame, error) { + var hdr [13]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return bondFrame{}, err + } + size := binary.BigEndian.Uint32(hdr[9:13]) + if size > 4*1024*1024 { + return bondFrame{}, fmt.Errorf("bond frame too large: %d", size) + } + f := bondFrame{ + typ: hdr[0], + seq: binary.BigEndian.Uint64(hdr[1:9]), + } + if size > 0 { + f.data = make([]byte, size) + if _, err := io.ReadFull(r, f.data); err != nil { + return bondFrame{}, err + } + } + return f, nil +} + +func closeWrite(conn net.Conn) { + type closeWriter interface { + CloseWrite() error + } + if cw, ok := conn.(closeWriter); ok { + if err := cw.CloseWrite(); err != nil { + log.Printf("CloseWrite failed: %v", err) + } + } +} + +type bondServerLane struct { + index uint16 + stream *smux.Stream + mu sync.Mutex +} + +type bondServerConn struct { + id uint64 + connectAddr string + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + lanesMu sync.RWMutex + lanes []*bondServerLane + want uint16 + ready chan struct{} + + recvCh chan bondFrame + once sync.Once +} + +type bondRegistry struct { + mu sync.Mutex + conns map[uint64]*bondServerConn +} + +var globalBondRegistry = &bondRegistry{conns: make(map[uint64]*bondServerConn)} + +func (r *bondRegistry) get(ctx context.Context, id uint64, connectAddr string) *bondServerConn { + r.mu.Lock() + defer r.mu.Unlock() + if c := r.conns[id]; c != nil { + return c + } + connCtx, cancel := context.WithCancel(ctx) + c := &bondServerConn{ + id: id, + connectAddr: connectAddr, + ctx: connCtx, + cancel: cancel, + done: make(chan struct{}), + ready: make(chan struct{}, 1), + recvCh: make(chan bondFrame, 1024), + } + r.conns[id] = c + go func() { + <-c.done + r.mu.Lock() + if r.conns[id] == c { + delete(r.conns, id) + } + r.mu.Unlock() + }() + return c +} + +func (c *bondServerConn) addLane(l *bondServerLane, laneCount uint16) { + c.lanesMu.Lock() + if laneCount > c.want { + c.want = laneCount + } + c.lanes = append(c.lanes, l) + count := len(c.lanes) + c.lanesMu.Unlock() + log.Printf("[bond %d] lane %d attached (lanes=%d)", c.id, l.index, count) + select { + case c.ready <- struct{}{}: + default: + } + + go c.readLane(l) + c.once.Do(func() { + go c.run() + }) +} + +func (c *bondServerConn) snapshotLanes() []*bondServerLane { + c.lanesMu.RLock() + defer c.lanesMu.RUnlock() + out := make([]*bondServerLane, len(c.lanes)) + copy(out, c.lanes) + return out +} + +func (c *bondServerConn) removeLane(l *bondServerLane) int { + c.lanesMu.Lock() + defer c.lanesMu.Unlock() + for i, lane := range c.lanes { + if lane == l { + c.lanes = append(c.lanes[:i], c.lanes[i+1:]...) + break + } + } + return len(c.lanes) +} + +func (c *bondServerConn) waitForInitialLanes() { + timer := time.NewTimer(bondLaneAttachTimeout) + defer timer.Stop() + for { + c.lanesMu.RLock() + count := len(c.lanes) + want := int(c.want) + c.lanesMu.RUnlock() + if want <= 0 || count >= want { + return + } + select { + case <-c.ctx.Done(): + return + case <-c.ready: + case <-timer.C: + log.Printf("[bond %d] starting with %d/%d lanes after attach timeout", c.id, count, want) + return + } + } +} + +func (c *bondServerConn) readLane(l *bondServerLane) { + for { + f, err := readBondFrame(l.stream) + if err != nil { + left := c.removeLane(l) + select { + case <-c.ctx.Done(): + default: + if err != io.EOF { + log.Printf("[bond %d] lane %d read error: %v (lanes=%d)", c.id, l.index, err, left) + } + if left == 0 { + c.cancel() + } + } + return + } + select { + case c.recvCh <- f: + case <-c.ctx.Done(): + return + } + } +} + +func (c *bondServerConn) run() { + defer close(c.done) + defer c.cancel() + + c.waitForInitialLanes() + + backendConn, err := net.DialTimeout("tcp", c.connectAddr, 10*time.Second) + if err != nil { + log.Printf("[bond %d] backend dial error: %s", c.id, err) + return + } + defer func() { + if err := backendConn.Close(); err != nil { + log.Printf("[bond %d] failed to close backend connection: %v", c.id, err) + } + }() + context.AfterFunc(c.ctx, func() { + now := time.Now() + if err := backendConn.SetDeadline(now); err != nil { + log.Printf("[bond %d] backend deadline error: %v", c.id, err) + } + for _, lane := range c.snapshotLanes() { + if err := lane.stream.SetDeadline(now); err != nil { + log.Printf("[bond %d] lane %d deadline error: %v", c.id, lane.index, err) + } + } + }) + log.Printf("[bond %d] backend connected", c.id) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + c.copyBondToBackend(backendConn) + }() + go func() { + defer wg.Done() + defer c.cancel() + c.copyBackendToBond(backendConn) + }() + wg.Wait() +} + +func (c *bondServerConn) copyBondToBackend(backendConn net.Conn) { + pending := make(map[uint64][]byte) + var expect uint64 + var finSeq *uint64 + + for { + if finSeq != nil && expect == *finSeq { + closeWrite(backendConn) + log.Printf("[bond %d] upload to backend finished chunks=%d", c.id, expect) + return + } + + select { + case <-c.ctx.Done(): + return + case f := <-c.recvCh: + switch f.typ { + case bondFrameData: + pending[f.seq] = f.data + case bondFrameFIN: + v := f.seq + if finSeq == nil || v < *finSeq { + finSeq = &v + } + default: + log.Printf("[bond %d] unknown frame type %d", c.id, f.typ) + return + } + + for { + data, ok := pending[expect] + if !ok { + break + } + delete(pending, expect) + if len(data) > 0 { + if _, err := backendConn.Write(data); err != nil { + log.Printf("[bond %d] backend write error: %v", c.id, err) + return + } + } + expect++ + } + } + } +} + +func (c *bondServerConn) copyBackendToBond(backendConn net.Conn) { + buf := make([]byte, bondMaxChunk) + var seq uint64 + var laneIdx uint64 + for { + n, err := backendConn.Read(buf) + if n > 0 { + data := make([]byte, n) + copy(data, buf[:n]) + if writeErr := c.writeToNextLane(bondFrameData, seq, data, &laneIdx); writeErr != nil { + log.Printf("[bond %d] lane write data error: %v", c.id, writeErr) + return + } + seq++ + } + if err != nil { + lanes := c.snapshotLanes() + for _, lane := range lanes { + lane.mu.Lock() + writeErr := writeBondFrame(lane.stream, bondFrameFIN, seq, nil) + lane.mu.Unlock() + if writeErr != nil && c.ctx.Err() == nil { + log.Printf("[bond %d] lane %d write FIN error: %v", c.id, lane.index, writeErr) + } + } + log.Printf("[bond %d] download from backend finished chunks=%d", c.id, seq) + return + } + select { + case <-c.ctx.Done(): + return + default: + } + } +} + +func (c *bondServerConn) writeToNextLane(typ byte, seq uint64, data []byte, laneIdx *uint64) error { + for { + lanes := c.snapshotLanes() + for attempts := 0; attempts < len(lanes); attempts++ { + lane := lanes[*laneIdx%uint64(len(lanes))] + (*laneIdx)++ + lane.mu.Lock() + err := writeBondFrame(lane.stream, typ, seq, data) + lane.mu.Unlock() + if err == nil { + return nil + } + left := c.removeLane(lane) + log.Printf("[bond %d] lane %d write error: %v (lanes=%d)", c.id, lane.index, err, left) + if left == 0 { + return err + } + } + select { + case <-c.ctx.Done(): + return c.ctx.Err() + case <-time.After(10 * time.Millisecond): + } + } +} + +func handleBondServerStream(ctx context.Context, stream *smux.Stream, connectAddr string) { + handleBondServerStreamWithHello(ctx, stream, connectAddr, readBondHello) +} + +func handleBondServerStreamAfterMagic(ctx context.Context, stream *smux.Stream, connectAddr string, magic [4]byte) { + handleBondServerStreamWithHello(ctx, stream, connectAddr, func(r io.Reader) (bondHello, error) { + return readBondHelloAfterMagic(r, magic) + }) +} + +func handleBondServerStreamWithHello(ctx context.Context, stream *smux.Stream, connectAddr string, readHello func(io.Reader) (bondHello, error)) { + defer func() { + if err := stream.Close(); err != nil && err != smux.ErrGoAway { + log.Printf("failed to close bond smux stream: %v", err) + } + }() + + hello, err := readHello(stream) + if err != nil { + log.Printf("bond hello error: %v", err) + return + } + + conn := globalBondRegistry.get(ctx, hello.connID, connectAddr) + conn.addLane(&bondServerLane{ + index: hello.laneIndex, + stream: stream, + }, hello.laneCount) + + select { + case <-ctx.Done(): + case <-conn.done: + } +} + +type prefixedConn struct { + net.Conn + prefix []byte +} + +func (c *prefixedConn) Read(p []byte) (int, error) { + if len(c.prefix) > 0 { + n := copy(p, c.prefix) + c.prefix = c.prefix[n:] + return n, nil + } + return c.Conn.Read(p) +} + // handleUDPConnection forwards DTLS packets to a UDP backend (WireGuard). func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) { serverConn, err := net.Dial("udp", connectAddr) @@ -141,6 +692,14 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) var wg sync.WaitGroup wg.Add(2) ctx2, cancel2 := context.WithCancel(ctx) + stats := &throughputStats{} + go stats.logEvery( + ctx2, + fmt.Sprintf("[DTLS %s]", conn.RemoteAddr()), + "dtls-to-backend", + "backend-to-dtls", + ) + context.AfterFunc(ctx2, func() { if err := conn.SetDeadline(time.Now()); err != nil { log.Printf("failed to set incoming deadline: %s", err) @@ -173,7 +732,8 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) log.Printf("Failed: %s", err1) return } - _, err1 = serverConn.Write(buf[:n]) + written, err1 := serverConn.Write(buf[:n]) + stats.addTx(written) if err1 != nil { log.Printf("Failed: %s", err1) return @@ -204,7 +764,8 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) log.Printf("Failed: %s", err1) return } - _, err1 = conn.Write(buf[:n]) + written, err1 := conn.Write(buf[:n]) + stats.addRx(written) if err1 != nil { log.Printf("Failed: %s", err1) return @@ -216,9 +777,19 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) // handleVLESSConnection creates a KCP+smux session over DTLS and forwards // each smux stream as a TCP connection to the backend (Xray/VLESS). -func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr string) { +func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr string, bond bool) { // 1. Create KCP session over DTLS - kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, true) + statsCtx, statsCancel := context.WithCancel(ctx) + defer statsCancel() + stats := &throughputStats{} + go stats.logEvery( + statsCtx, + fmt.Sprintf("[VLESS %s]", dtlsConn.RemoteAddr()), + "to-client", + "from-client", + ) + + kcpSess, err := tcputil.NewKCPOverDTLS(&countingConn{Conn: dtlsConn, stats: stats}, true) if err != nil { log.Printf("KCP session error: %s", err) return @@ -260,6 +831,23 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s go func(s *smux.Stream) { defer wg.Done() + var prefix [4]byte + if _, err := io.ReadFull(s, prefix[:]); err != nil { + if err != io.EOF && err != io.ErrUnexpectedEOF { + log.Printf("smux stream prefix read error: %v", err) + } + _ = s.Close() + return + } + if string(prefix[:]) == bondMagic { + log.Printf("auto-detected bond smux stream") + handleBondServerStreamAfterMagic(ctx, s, connectAddr, prefix) + return + } + if bond { + log.Printf("non-bond smux stream accepted while -vless-bond is enabled") + } + defer func() { if err := s.Close(); err != nil && err != smux.ErrGoAway { log.Printf("failed to close smux stream: %v", err) @@ -279,7 +867,7 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s }() // Bidirectional copy - pipeConn(ctx, s, backendConn) + pipeConn(ctx, &prefixedConn{Conn: s, prefix: prefix[:]}, backendConn) }(stream) } wg.Wait() diff --git a/tcputil/tcputil.go b/tcputil/tcputil.go index 896d31b..bb2aab2 100644 --- a/tcputil/tcputil.go +++ b/tcputil/tcputil.go @@ -2,12 +2,111 @@ package tcputil import ( "net" + "os" + "strconv" + "strings" "time" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" ) +type kcpProfile struct { + nodelay int + interval int + resend int + nc int + sndWnd int + rcvWnd int + mtu int + ackNoDelay bool +} + +func selectedKCPProfile() kcpProfile { + profile := strings.ToLower(strings.TrimSpace(os.Getenv("VK_TURN_KCP_PROFILE"))) + var cfg kcpProfile + switch profile { + case "legacy", "fast": + cfg = kcpProfile{ + nodelay: 1, + interval: 10, + resend: 2, + nc: 1, + sndWnd: 4096, + rcvWnd: 4096, + mtu: 1280, + ackNoDelay: true, + } + case "cc", "balanced": + cfg = kcpProfile{ + nodelay: 1, + interval: 20, + resend: 2, + nc: 0, + sndWnd: 512, + rcvWnd: 512, + mtu: 1200, + ackNoDelay: true, + } + case "slow", "conservative": + cfg = kcpProfile{ + nodelay: 0, + interval: 40, + resend: 2, + nc: 0, + sndWnd: 256, + rcvWnd: 256, + mtu: 1150, + ackNoDelay: false, + } + default: + cfg = kcpProfile{ + nodelay: 1, + interval: 20, + resend: 2, + nc: 1, + sndWnd: 512, + rcvWnd: 512, + mtu: 1200, + ackNoDelay: true, + } + } + + cfg.nodelay = envInt("VK_TURN_KCP_NODELAY", cfg.nodelay) + cfg.interval = envInt("VK_TURN_KCP_INTERVAL", cfg.interval) + cfg.resend = envInt("VK_TURN_KCP_RESEND", cfg.resend) + cfg.nc = envInt("VK_TURN_KCP_NC", cfg.nc) + cfg.sndWnd = envInt("VK_TURN_KCP_SNDWND", cfg.sndWnd) + cfg.rcvWnd = envInt("VK_TURN_KCP_RCVWND", cfg.rcvWnd) + cfg.mtu = envInt("VK_TURN_KCP_MTU", cfg.mtu) + cfg.ackNoDelay = envBool("VK_TURN_KCP_ACK_NODELAY", cfg.ackNoDelay) + return cfg +} + +func envInt(name string, fallback int) int { + raw := strings.TrimSpace(os.Getenv(name)) + if raw == "" { + return fallback + } + value, err := strconv.Atoi(raw) + if err != nil { + return fallback + } + return value +} + +func envBool(name string, fallback bool) bool { + raw := strings.ToLower(strings.TrimSpace(os.Getenv(name))) + switch raw { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return fallback + } +} + // DtlsPacketConn wraps a net.Conn (DTLS) as a net.PacketConn for KCP. // Each DTLS Read/Write preserves message boundaries (datagram semantics). type DtlsPacketConn struct { @@ -81,13 +180,11 @@ func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (*kcp.UDPSession, error) { } } - // Tune KCP for TURN tunnel: - // - NoDelay mode for lower latency - // - Window sizes suitable for ~5Mbit/s - sess.SetNoDelay(1, 10, 2, 1) // nodelay, interval(ms), resend, nc - sess.SetWindowSize(4096, 4096) - sess.SetMtu(1280) // conservative MTU to fit inside DTLS+TURN - sess.SetACKNoDelay(true) + profile := selectedKCPProfile() + sess.SetNoDelay(profile.nodelay, profile.interval, profile.resend, profile.nc) + sess.SetWindowSize(profile.sndWnd, profile.rcvWnd) + sess.SetMtu(profile.mtu) + sess.SetACKNoDelay(profile.ackNoDelay) return sess, nil }