Browse Source

feat: add connection bonding and configurable kcp profiles

Implement multi-lane transport (bonding) for VLESS to aggregate bandwidth
from multiple KCP/DTLS sessions. Add environment-based KCP tuning
and throughput statistics.
- Support multi-lane bonding in client and server
- Add KCP profiles (fast, balanced, slow) via VK_TURN_KCP_PROFILE
- Auto-detect bonding streams via magic prefix
- Add real-time throughput logging for active connections
pull/162/head
Moroka8 1 month ago
parent
commit
96831328bf
  1. 1
      Dockerfile
  2. 543
      client/main.go
  3. 7
      docker-entrypoint.sh
  4. 600
      server/main.go
  5. 111
      tcputil/tcputil.go

1
Dockerfile

@ -15,6 +15,7 @@ COPY docker-entrypoint.sh .
COPY --from=builder /build/vk-turn-proxy .
RUN chmod +x docker-entrypoint.sh
EXPOSE 56000/tcp
EXPOSE 56000/udp
ENTRYPOINT ["./docker-entrypoint.sh"]

543
client/main.go

@ -9,6 +9,7 @@ import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"flag"
@ -121,6 +122,102 @@ var packetPool = sync.Pool{
New: func() any { return &UDPPacket{Data: make([]byte, 2048)} },
}
// throughputStats holds a pair of byte counters backed by atomics, so they
// can be updated and read without an additional lock.
type throughputStats struct {
	tx atomic.Uint64 // bytes reported through addTx
	rx atomic.Uint64 // bytes reported through addRx
}

// addTx adds n to the transmit counter. Zero and negative n are ignored.
func (s *throughputStats) addTx(n int) {
	if n <= 0 {
		return
	}
	s.tx.Add(uint64(n))
}

// addRx adds n to the receive counter. Zero and negative n are ignored.
func (s *throughputStats) addRx(n int) {
	if n <= 0 {
		return
	}
	s.rx.Add(uint64(n))
}
// logEvery wakes up every five seconds and, when either counter moved since
// the previous tick, logs the rate over that interval and the running totals.
// label prefixes the line; txName and rxName name the two directions.
// It returns when ctx is done.
func (s *throughputStats) logEvery(ctx context.Context, label, txName, rxName string) {
	const interval = 5 * time.Second

	tick := time.NewTicker(interval)
	defer tick.Stop()

	var lastTx, lastRx uint64
	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		curTx := s.tx.Load()
		curRx := s.rx.Load()
		movedTx, movedRx := curTx-lastTx, curRx-lastRx
		lastTx, lastRx = curTx, curRx
		if movedTx == 0 && movedRx == 0 {
			// Nothing was counted during this interval; stay quiet.
			continue
		}
		log.Printf(
			"%s throughput: %s=%s %s=%s total_%s=%s total_%s=%s",
			label,
			txName, formatBitsPerSecond(movedTx, interval),
			rxName, formatBitsPerSecond(movedRx, interval),
			txName, formatByteCount(curTx),
			rxName, formatByteCount(curRx),
		)
	}
}
// formatBitsPerSecond renders the rate of bytes transferred during interval
// as a human-readable bit rate ("bit/s", "kbit/s" or "Mbit/s", decimal units).
// A non-positive interval is treated as one second.
func formatBitsPerSecond(bytes uint64, interval time.Duration) string {
	if interval <= 0 {
		interval = time.Second
	}
	// Convert to float64 before multiplying: bytes*8 evaluated in uint64
	// wraps around for counts of 2^61 and above and would report a bogus rate.
	bps := float64(bytes) * 8 / interval.Seconds()
	switch {
	case bps >= 1_000_000:
		return fmt.Sprintf("%.2f Mbit/s", bps/1_000_000)
	case bps >= 1_000:
		return fmt.Sprintf("%.1f kbit/s", bps/1_000)
	default:
		return fmt.Sprintf("%.0f bit/s", bps)
	}
}
// formatByteCount renders a byte total using binary units: MiB with two
// decimals, KiB with one decimal, or a plain byte count below 1024.
func formatByteCount(bytes uint64) string {
	const (
		kib = 1024
		mib = 1024 * kib
	)
	switch {
	case bytes >= mib:
		return fmt.Sprintf("%.2f MiB", float64(bytes)/mib)
	case bytes >= kib:
		return fmt.Sprintf("%.1f KiB", float64(bytes)/kib)
	default:
		return fmt.Sprintf("%d B", bytes)
	}
}
// countingConn decorates a net.Conn so that every Read and Write also
// reports the transferred byte count to stats.
type countingConn struct {
	net.Conn
	stats *throughputStats
}

// Read delegates to the wrapped connection and records the bytes read as rx.
func (c *countingConn) Read(p []byte) (int, error) {
	got, readErr := c.Conn.Read(p)
	c.stats.addRx(got)
	return got, readErr
}

// Write delegates to the wrapped connection and records the bytes written as tx.
func (c *countingConn) Write(p []byte) (int, error) {
	sent, writeErr := c.Conn.Write(p)
	c.stats.addTx(sent)
	return sent, writeErr
}
// newDirectNet returns a zero-value directNet as a transport.Net.
func newDirectNet() transport.Net {
return directNet{}
}
@ -1754,6 +1851,9 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD
wg := sync.WaitGroup{}
wg.Add(1)
turnctx, turncancel := context.WithCancel(ctx)
stats := &throughputStats{}
go stats.logEvery(turnctx, fmt.Sprintf("[STREAM %d] TURN", streamID), "to-turn", "from-turn")
context.AfterFunc(turnctx, func() {
if err := relayConn.SetDeadline(time.Now()); err != nil {
log.Printf("Failed to set relay deadline: %s", err)
@ -1779,7 +1879,8 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD
internalPipeAddr.Store(addr1)
_, err1 = relayConn.WriteTo(buf[:n], peer)
written, err1 := relayConn.WriteTo(buf[:n], peer)
stats.addTx(written)
if err1 != nil {
return
}
@ -1801,6 +1902,7 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD
}
if addr, ok := addr1.(net.Addr); ok {
stats.addRx(n)
if _, err := conn2.WriteTo(buf[:n], addr); err != nil {
return
}
@ -1937,6 +2039,7 @@ func main() {
udp := flag.Bool("udp", false, "connect to TURN with UDP")
direct := flag.Bool("no-dtls", false, "connect without obfuscation. DO NOT USE")
vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets")
vlessBond := flag.Bool("vless-bond", false, "bond one VLESS TCP connection across all active smux sessions")
debugFlag := flag.Bool("debug", false, "enable debug logging")
manualCaptchaFlag := flag.Bool("manual-captcha", false, "skip auto captcha solving, use manual mode immediately")
flag.Parse()
@ -1996,7 +2099,7 @@ func main() {
}
if *vlessMode {
runVLESSMode(ctx, params, peer, *listen, *n)
runVLESSMode(ctx, params, peer, *listen, *n, *vlessBond)
return
}
@ -2097,22 +2200,35 @@ func main() {
}
// sessionPool manages a pool of smux sessions for round-robin TCP distribution.
// pooledSession pairs one smux session with the id it was registered under
// and per-session counters used in log output.
type pooledSession struct {
// id is the identifier passed to sessionPool.add.
id int
// sess is the underlying smux session that streams are opened on.
sess *smux.Session
// active is incremented when a TCP connection or bond lane starts using
// the session and decremented when it finishes.
active atomic.Int32
// opened counts connections/lanes ever started on this session.
opened atomic.Uint64
// closed counts connections/lanes that have finished on this session.
closed atomic.Uint64
// toSession accumulates payload bytes sent from local TCP into the session.
toSession atomic.Uint64
// fromSession accumulates payload bytes received from the session.
fromSession atomic.Uint64
}
type sessionPool struct {
mu sync.RWMutex
sessions []*smux.Session
counter atomic.Uint64
mu sync.RWMutex
sessions []*pooledSession
counter atomic.Uint64
connCounter atomic.Uint64
}
func (p *sessionPool) add(s *smux.Session) {
func (p *sessionPool) add(id int, s *smux.Session) *pooledSession {
ps := &pooledSession{id: id, sess: s}
p.mu.Lock()
p.sessions = append(p.sessions, s)
p.sessions = append(p.sessions, ps)
p.mu.Unlock()
return ps
}
func (p *sessionPool) remove(s *smux.Session) {
func (p *sessionPool) remove(ps *pooledSession) {
p.mu.Lock()
for i, sess := range p.sessions {
if sess == s {
if sess == ps {
p.sessions = append(p.sessions[:i], p.sessions[i+1:]...)
break
}
@ -2120,25 +2236,346 @@ func (p *sessionPool) remove(s *smux.Session) {
p.mu.Unlock()
}
func (p *sessionPool) pick() *smux.Session {
func (p *sessionPool) pick() *pooledSession {
p.mu.RLock()
defer p.mu.RUnlock()
n := len(p.sessions)
if n == 0 {
return nil
}
idx := p.counter.Add(1) % uint64(n)
idx := (p.counter.Add(1) - 1) % uint64(n)
return p.sessions[idx]
}
// nextConnID atomically increments the pool's connection counter and returns
// the new value.
func (p *sessionPool) nextConnID() uint64 {
return p.connCounter.Add(1)
}
// snapshot returns a freshly allocated slice holding every pooled session
// whose smux session is not closed at the time of the call. The result is
// never nil and may be empty.
func (p *sessionPool) snapshot() []*pooledSession {
	p.mu.RLock()
	defer p.mu.RUnlock()

	live := make([]*pooledSession, 0, len(p.sessions))
	for i := range p.sessions {
		if entry := p.sessions[i]; !entry.sess.IsClosed() {
			live = append(live, entry)
		}
	}
	return live
}
// count returns how many sessions are currently registered in the pool.
// Unlike snapshot it does not filter out sessions that are already closed.
func (p *sessionPool) count() int {
p.mu.RLock()
defer p.mu.RUnlock()
return len(p.sessions)
}
// Wire-format constants of the bonding protocol.
const (
// bondVersion is the version byte written into the hello header.
bondVersion = 1
// bondMagic is the 4-byte prefix that starts the hello header.
bondMagic = "VLB1"
// bondFrameData marks a frame that carries payload bytes.
bondFrameData byte = 1
// bondFrameFIN marks end of stream; its seq is the number of data frames
// the sender produced.
bondFrameFIN byte = 2
// bondMaxChunk is the size of the local read buffer and therefore the
// largest payload placed into a single data frame.
bondMaxChunk = 16 * 1024
)
// bondFrame is one decoded bonding frame.
type bondFrame struct {
// typ is the frame type byte (bondFrameData or bondFrameFIN).
typ byte
// seq orders data frames for reassembly; for a FIN frame it is the
// sequence number at which the stream ends.
seq uint64
// data is the payload; it is nil when the frame carried no bytes.
data []byte
}
// bondClientLane is one smux stream used as a lane of a bonded connection.
type bondClientLane struct {
// ps is the pooled session the stream was opened on.
ps *pooledSession
// stream carries the hello header followed by bond frames.
stream *smux.Stream
// mu is held around each frame write on stream.
mu sync.Mutex
// dead is set once a read or a write on the lane has failed; dead lanes
// are skipped by the writers.
dead atomic.Bool
}
// writeBondHello sends the 17-byte lane handshake in a single Write:
// the magic, the version byte, then connID (8 bytes), laneIndex (2 bytes)
// and laneCount (2 bytes), all big-endian.
func writeBondHello(w io.Writer, connID uint64, laneIndex, laneCount uint16) error {
	hello := make([]byte, 0, 17)
	hello = append(hello, bondMagic...)
	hello = append(hello, bondVersion)
	hello = binary.BigEndian.AppendUint64(hello, connID)
	hello = binary.BigEndian.AppendUint16(hello, laneIndex)
	hello = binary.BigEndian.AppendUint16(hello, laneCount)

	_, err := w.Write(hello)
	return err
}
// writeBondFrame writes one frame to w: a 13-byte header (type, big-endian
// 8-byte seq, big-endian 4-byte payload length) followed by the payload,
// if any.
//
// Payloads larger than 4 MiB are rejected before anything is written. The
// reader refuses frames above that size, and the 32-bit length field would
// otherwise be silently truncated for very large slices, desynchronising
// the stream.
func writeBondFrame(w io.Writer, typ byte, seq uint64, data []byte) error {
	// Keep in sync with the limit enforced by readBondFrame.
	const maxPayload = 4 * 1024 * 1024
	if len(data) > maxPayload {
		return fmt.Errorf("bond frame too large: %d", len(data))
	}

	var hdr [13]byte
	hdr[0] = typ
	binary.BigEndian.PutUint64(hdr[1:9], seq)
	binary.BigEndian.PutUint32(hdr[9:13], uint32(len(data)))
	if _, err := w.Write(hdr[:]); err != nil {
		return err
	}
	if len(data) == 0 {
		// Header-only frame (for example FIN): nothing more to send.
		return nil
	}
	_, err := w.Write(data)
	return err
}
// readBondFrame reads exactly one frame from r: a 13-byte header followed by
// the number of payload bytes announced in it. A frame announcing more than
// 4 MiB is rejected without reading its payload. On any error the returned
// frame is the zero value.
func readBondFrame(r io.Reader) (bondFrame, error) {
	var header [13]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return bondFrame{}, err
	}

	payloadLen := binary.BigEndian.Uint32(header[9:])
	if payloadLen > 4*1024*1024 {
		return bondFrame{}, fmt.Errorf("bond frame too large: %d", payloadLen)
	}

	frame := bondFrame{
		typ: header[0],
		seq: binary.BigEndian.Uint64(header[1:9]),
	}
	if payloadLen == 0 {
		// Header-only frame; data stays nil.
		return frame, nil
	}

	frame.data = make([]byte, payloadLen)
	if _, err := io.ReadFull(r, frame.data); err != nil {
		return bondFrame{}, err
	}
	return frame, nil
}
// closeWrite half-closes conn for writing when its concrete type offers a
// CloseWrite method; other connection types are left untouched. A failure is
// logged only when debug logging is enabled.
func closeWrite(conn net.Conn) {
	halfCloser, ok := conn.(interface{ CloseWrite() error })
	if !ok {
		return
	}
	if err := halfCloser.CloseWrite(); err != nil && isDebug {
		log.Printf("CloseWrite failed: %v", err)
	}
}
// handleBondedTCP serves one local TCP connection by striping it across one
// smux stream ("lane") per usable candidate session.
//
// It opens a stream on every candidate that is not closed, sends a hello on
// each, then runs three kinds of goroutines until the connection ends: one
// reader per lane feeding recvCh, an uploader (copyTCPToBond) and a
// downloader (copyBondToTCP). tcpConn and all lane streams are closed before
// it returns, and the per-session counters are updated.
func handleBondedTCP(ctx context.Context, tcpConn net.Conn, connID uint64, candidates []*pooledSession) {
	defer func() { _ = tcpConn.Close() }()

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Open all streams before sending any hello so the advertised lane count
	// matches the streams that actually exist. Advertising len(candidates)
	// would make the peer wait for lanes that were skipped here.
	opened := make([]*bondClientLane, 0, len(candidates))
	for _, ps := range candidates {
		if ps.sess.IsClosed() {
			continue
		}
		stream, err := ps.sess.OpenStream()
		if err != nil {
			log.Printf("[bond %d] session %d open stream error: %s", connID, ps.id, err)
			continue
		}
		opened = append(opened, &bondClientLane{ps: ps, stream: stream})
	}

	lanes := make([]*bondClientLane, 0, len(opened))
	laneIDs := make([]string, 0, len(opened))
	for i, lane := range opened {
		if err := writeBondHello(lane.stream, connID, uint16(i), uint16(len(opened))); err != nil {
			log.Printf("[bond %d] session %d hello error: %s", connID, lane.ps.id, err)
			_ = lane.stream.Close()
			continue
		}
		lane.ps.opened.Add(1)
		lane.ps.active.Add(1)
		lanes = append(lanes, lane)
		laneIDs = append(laneIDs, strconv.Itoa(lane.ps.id))
	}

	if len(lanes) == 0 {
		log.Printf("[bond %d] no usable lanes, rejecting TCP from %s", connID, tcpConn.RemoteAddr())
		return
	}

	// On cancellation, expire all deadlines so blocked reads/writes return.
	context.AfterFunc(ctx, func() {
		now := time.Now()
		if err := tcpConn.SetDeadline(now); err != nil && isDebug {
			log.Printf("[bond %d] local TCP deadline error: %v", connID, err)
		}
		for _, lane := range lanes {
			if err := lane.stream.SetDeadline(now); err != nil && isDebug {
				log.Printf("[bond %d] session %d stream deadline error: %v", connID, lane.ps.id, err)
			}
		}
	})

	log.Printf("[bond %d] TCP accept from=%s lanes=%d [%s]", connID, tcpConn.RemoteAddr(), len(lanes), strings.Join(laneIDs, ","))

	defer func() {
		for _, lane := range lanes {
			_ = lane.stream.Close()
			active := lane.ps.active.Add(-1)
			closed := lane.ps.closed.Add(1)
			log.Printf("[bond %d] lane session %d close active=%d closed=%d totals: to-session=%s from-session=%s",
				connID, lane.ps.id, active, closed,
				formatByteCount(lane.ps.toSession.Load()), formatByteCount(lane.ps.fromSession.Load()))
		}
	}()

	// One reader per lane funnels decoded frames into recvCh.
	recvCh := make(chan bondFrame, 1024)
	var readWG sync.WaitGroup
	for _, lane := range lanes {
		readWG.Add(1)
		go func(l *bondClientLane) {
			defer readWG.Done()
			for {
				f, err := readBondFrame(l.stream)
				if err != nil {
					l.dead.Store(true)
					select {
					case <-ctx.Done():
						// Shutting down: the error is expected, stay quiet.
					default:
						if err != io.EOF {
							log.Printf("[bond %d] session %d read frame error: %v", connID, l.ps.id, err)
						}
					}
					return
				}
				if f.typ == bondFrameData {
					l.ps.fromSession.Add(uint64(len(f.data)))
				}
				select {
				case recvCh <- f:
				case <-ctx.Done():
					return
				}
			}
		}(lane)
	}
	// recvCh is closed only after every reader has exited, so no reader can
	// send on a closed channel.
	go func() {
		readWG.Wait()
		close(recvCh)
	}()

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		copyTCPToBond(ctx, connID, tcpConn, lanes)
	}()
	go func() {
		defer wg.Done()
		copyBondToTCP(ctx, connID, tcpConn, recvCh)
		// The download side finishing ends the whole bonded connection.
		cancel()
	}()
	wg.Wait()
}
// copyTCPToBond reads tcpConn in chunks of at most bondMaxChunk bytes and
// sends each chunk as a sequenced data frame over the next live lane.
// When the local read ends (EOF or error) it writes a FIN frame, carrying the
// number of data frames sent, to every lane not marked dead. It returns
// early if no lane accepts a frame or once ctx is done.
func copyTCPToBond(ctx context.Context, connID uint64, tcpConn net.Conn, lanes []*bondClientLane) {
	buf := make([]byte, bondMaxChunk)
	var seq uint64
	var laneIdx uint64
	for {
		n, err := tcpConn.Read(buf)
		if n > 0 {
			// buf[:n] is passed directly: the lane write completes before buf
			// is reused, and io.Writer implementations must not retain the
			// slice, so a per-chunk copy is unnecessary.
			lane, writeErr := writeBondFrameToNextLane(ctx, lanes, bondFrameData, seq, buf[:n], &laneIdx)
			if writeErr != nil {
				log.Printf("[bond %d] write data error: %v", connID, writeErr)
				return
			}
			lane.ps.toSession.Add(uint64(n))
			seq++
		}
		if err != nil {
			if isDebug && err != io.EOF {
				log.Printf("[bond %d] local TCP read finished with error: %v", connID, err)
			}
			// Announce end of stream on every lane still believed alive.
			for _, lane := range lanes {
				if lane.dead.Load() {
					continue
				}
				lane.mu.Lock()
				writeErr := writeBondFrame(lane.stream, bondFrameFIN, seq, nil)
				lane.mu.Unlock()
				if writeErr != nil && ctx.Err() == nil {
					log.Printf("[bond %d] session %d write FIN error: %v", connID, lane.ps.id, writeErr)
				}
			}
			log.Printf("[bond %d] upload finished chunks=%d", connID, seq)
			return
		}
		select {
		case <-ctx.Done():
			return
		default:
		}
	}
}
// writeBondFrameToNextLane tries to send one frame, visiting lanes in
// rotation starting at *laneIdx and advancing *laneIdx for every lane
// visited. Lanes marked dead are skipped; a lane whose write fails is marked
// dead. It returns the lane that accepted the frame, ctx.Err() if the
// context ended, or a "no live bond lanes" error after one full rotation.
func writeBondFrameToNextLane(ctx context.Context, lanes []*bondClientLane, typ byte, seq uint64, data []byte, laneIdx *uint64) (*bondClientLane, error) {
	total := uint64(len(lanes))
	for range lanes {
		candidate := lanes[*laneIdx%total]
		*laneIdx++
		if candidate.dead.Load() {
			continue
		}

		candidate.mu.Lock()
		writeErr := writeBondFrame(candidate.stream, typ, seq, data)
		candidate.mu.Unlock()
		if writeErr == nil {
			return candidate, nil
		}

		candidate.dead.Store(true)
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, ctxErr
	}
	return nil, fmt.Errorf("no live bond lanes")
}
// copyBondToTCP reassembles frames arriving on recvCh into their original
// order and writes the payloads to tcpConn.
//
// Data frames are buffered by sequence number until the next expected one is
// available. A FIN frame records where the stream ends (the smallest value
// seen wins); once everything up to that point has been written, the write
// side of tcpConn is half-closed. The function also returns when ctx is done,
// recvCh is closed, an unknown frame type arrives, or a local write fails.
func copyBondToTCP(ctx context.Context, connID uint64, tcpConn net.Conn, recvCh <-chan bondFrame) {
	pending := make(map[uint64][]byte)
	var expect uint64
	var finSeq *uint64
	for {
		// >= rather than ==: a FIN below the already delivered position must
		// still end the stream instead of leaving the loop waiting forever.
		if finSeq != nil && expect >= *finSeq {
			closeWrite(tcpConn)
			log.Printf("[bond %d] download finished chunks=%d", connID, expect)
			return
		}
		select {
		case <-ctx.Done():
			return
		case f, ok := <-recvCh:
			if !ok {
				return
			}
			switch f.typ {
			case bondFrameData:
				// A frame older than the delivery position can never be
				// consumed; storing it would only leak memory.
				if f.seq >= expect {
					pending[f.seq] = f.data
				}
			case bondFrameFIN:
				v := f.seq
				if finSeq == nil || v < *finSeq {
					finSeq = &v
				}
			default:
				log.Printf("[bond %d] unknown frame type %d", connID, f.typ)
				return
			}
			// Flush every frame that is now contiguous with what was delivered.
			for {
				data, ok := pending[expect]
				if !ok {
					break
				}
				delete(pending, expect)
				if len(data) > 0 {
					if _, err := tcpConn.Write(data); err != nil {
						log.Printf("[bond %d] local TCP write error: %v", connID, err)
						return
					}
				}
				expect++
			}
		}
	}
}
// runVLESSMode implements TCP forwarding with round-robin across N TURN sessions.
func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string, numSessions int) {
func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listenAddr string, numSessions int, bond bool) {
pool := &sessionPool{}
// Start N session maintainers with staggered startup
@ -2182,7 +2619,11 @@ func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listen
}
context.AfterFunc(ctx, func() { _ = wrappedListener.Close() })
log.Printf("VLESS mode: listening on %s (round-robin across %d sessions)", listenAddr, numSessions)
if bond {
log.Printf("VLESS bond mode: listening on %s (striping each TCP connection across active sessions)", listenAddr)
} else {
log.Printf("VLESS mode: listening on %s (round-robin across %d sessions)", listenAddr, numSessions)
}
var wgConn sync.WaitGroup
for {
@ -2199,25 +2640,60 @@ func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listen
continue
}
sess := pool.pick()
if sess == nil || sess.IsClosed() {
if bond {
connID := (uint64(time.Now().UnixNano()) << 16) ^ pool.nextConnID()
lanes := pool.snapshot()
if len(lanes) == 0 {
log.Printf("No active sessions, rejecting connection")
_ = tcpConn.Close()
continue
}
wgConn.Add(1)
go func(tc net.Conn, connID uint64, lanes []*pooledSession) {
defer wgConn.Done()
handleBondedTCP(ctx, tc, connID, lanes)
}(tcpConn, connID, lanes)
continue
}
ps := pool.pick()
if ps == nil || ps.sess.IsClosed() {
log.Printf("No active sessions, rejecting connection")
_ = tcpConn.Close()
continue
}
connID := pool.nextConnID()
opened := ps.opened.Add(1)
active := ps.active.Add(1)
log.Printf("[session %d] TCP accept #%d from=%s active=%d opened=%d pool=%d",
ps.id, connID, tcpConn.RemoteAddr(), active, opened, pool.count())
wgConn.Add(1)
go func(tc net.Conn, s *smux.Session) {
go func(tc net.Conn, ps *pooledSession, connID uint64) {
defer wgConn.Done()
defer func() { _ = tc.Close() }()
stream, err := s.OpenStream()
defer func() {
active := ps.active.Add(-1)
closed := ps.closed.Add(1)
log.Printf("[session %d] TCP close #%d active=%d closed=%d totals: to-session=%s from-session=%s",
ps.id, connID, active, closed,
formatByteCount(ps.toSession.Load()), formatByteCount(ps.fromSession.Load()))
}()
stream, err := ps.sess.OpenStream()
if err != nil {
log.Printf("smux open stream error: %s", err)
log.Printf("[session %d] smux open stream error for TCP #%d: %s", ps.id, connID, err)
return
}
defer func() { _ = stream.Close() }()
pipe(ctx, tc, stream)
}(tcpConn, sess)
fromSession, toSession := pipe(ctx, tc, stream)
ps.fromSession.Add(uint64(fromSession))
ps.toSession.Add(uint64(toSession))
log.Printf("[session %d] TCP done #%d local<-session=%s local->session=%s",
ps.id, connID, formatByteCount(uint64(fromSession)), formatByteCount(uint64(toSession)))
}(tcpConn, ps, connID)
}
}
@ -2241,20 +2717,20 @@ func maintainVLESSSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr
continue
}
pool.add(smuxSess)
ps := pool.add(id, smuxSess)
log.Printf("[session %d] connected (active: %d)", id, pool.count())
for !smuxSess.IsClosed() {
select {
case <-ctx.Done():
pool.remove(smuxSess)
pool.remove(ps)
cleanup()
return
case <-time.After(1 * time.Second):
}
}
pool.remove(smuxSess)
pool.remove(ps)
cleanup()
log.Printf("[session %d] disconnected (active: %d), reconnecting...", id, pool.count())
@ -2384,7 +2860,12 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i
log.Printf("DTLS connection established")
// 5. Create KCP session over DTLS
kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, false)
statsCtx, statsCancel := context.WithCancel(ctx)
cleanupFns = append(cleanupFns, statsCancel)
stats := &throughputStats{}
go stats.logEvery(statsCtx, fmt.Sprintf("[session %d] VLESS", id), "to-turn", "from-turn")
kcpSess, err := tcputil.NewKCPOverDTLS(&countingConn{Conn: dtlsConn, stats: stats}, false)
if err != nil {
cleanup()
return nil, nil, fmt.Errorf("KCP session: %w", err)
@ -2425,7 +2906,8 @@ func (r *relayPacketConn) SetReadDeadline(t time.Time) error { return r.relay.S
func (r *relayPacketConn) SetWriteDeadline(t time.Time) error { return r.relay.SetWriteDeadline(t) }
// pipe copies data bidirectionally between two connections.
func pipe(ctx context.Context, c1, c2 net.Conn) {
// It returns bytes copied as c1<-c2 and c2<-c1.
func pipe(ctx context.Context, c1, c2 net.Conn) (int64, int64) {
ctx2, cancel := context.WithCancel(ctx)
context.AfterFunc(ctx2, func() {
if err := c1.SetDeadline(time.Now()); err != nil {
@ -2437,11 +2919,15 @@ func pipe(ctx context.Context, c1, c2 net.Conn) {
})
var wg sync.WaitGroup
var c1FromC2 int64
var c2FromC1 int64
wg.Add(2)
go func() {
defer wg.Done()
defer cancel()
if _, err := io.Copy(c1, c2); err != nil {
n, err := io.Copy(c1, c2)
c1FromC2 = n
if err != nil {
if isDebug {
log.Printf("pipe: c1<-c2 copy error: %v", err)
}
@ -2450,7 +2936,9 @@ func pipe(ctx context.Context, c1, c2 net.Conn) {
go func() {
defer wg.Done()
defer cancel()
if _, err := io.Copy(c2, c1); err != nil {
n, err := io.Copy(c2, c1)
c2FromC1 = n
if err != nil {
if isDebug {
log.Printf("pipe: c2<-c1 copy error: %v", err)
}
@ -2467,4 +2955,5 @@ func pipe(ctx context.Context, c1, c2 net.Conn) {
log.Printf("pipe: failed to reset deadline c2: %v", err)
}
}
return c1FromC2, c2FromC1
}

7
docker-entrypoint.sh

@ -8,4 +8,9 @@ if [ "${VLESS_MODE}" = "true" ]; then
VLESS_FLAG="-vless"
fi
exec ./vk-turn-proxy -listen 0.0.0.0:56000 -connect "$CONNECT" $VLESS_FLAG
BOND_FLAG=""
if [ "${VLESS_BOND}" = "true" ]; then
BOND_FLAG="-vless-bond"
fi
exec ./vk-turn-proxy -listen 0.0.0.0:56000 -connect "$CONNECT" $VLESS_FLAG $BOND_FLAG

600
server/main.go

@ -2,6 +2,7 @@ package main
import (
"context"
"encoding/binary"
"flag"
"fmt"
"io"
@ -10,6 +11,7 @@ import (
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
@ -23,6 +25,7 @@ func main() {
listen := flag.String("listen", "0.0.0.0:56000", "listen on ip:port")
connect := flag.String("connect", "", "connect to ip:port")
vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets")
vlessBond := flag.Bool("vless-bond", false, "bond one VLESS TCP connection across all active smux sessions")
flag.Parse()
ctx, cancel := context.WithCancel(context.Background())
@ -44,6 +47,7 @@ func main() {
if len(*connect) == 0 {
log.Panicf("server address is required")
}
log.Printf("Starting server listen=%s connect=%s vless=%t vless-bond=%t bond-autodetect=true", *listen, *connect, *vlessMode, *vlessBond)
// Generate a certificate and private key to secure the connection
certificate, genErr := selfsign.GenerateSelfSigned()
if genErr != nil {
@ -115,7 +119,7 @@ func main() {
log.Println("Handshake done")
if *vlessMode {
handleVLESSConnection(ctx, dtlsConn, *connect)
handleVLESSConnection(ctx, dtlsConn, *connect, *vlessBond)
} else {
handleUDPConnection(ctx, conn, *connect)
}
@ -125,6 +129,553 @@ func main() {
}
}
// throughputStats holds a pair of byte counters backed by atomics, so they
// can be updated and read without an additional lock.
type throughputStats struct {
	tx atomic.Uint64 // bytes reported through addTx
	rx atomic.Uint64 // bytes reported through addRx
}

// addTx adds n to the transmit counter. Zero and negative n are ignored.
func (s *throughputStats) addTx(n int) {
	if n <= 0 {
		return
	}
	s.tx.Add(uint64(n))
}

// addRx adds n to the receive counter. Zero and negative n are ignored.
func (s *throughputStats) addRx(n int) {
	if n <= 0 {
		return
	}
	s.rx.Add(uint64(n))
}
// logEvery wakes up every five seconds and, when either counter moved since
// the previous tick, logs the rate over that interval and the running totals.
// label prefixes the line; txName and rxName name the two directions.
// It returns when ctx is done.
func (s *throughputStats) logEvery(ctx context.Context, label, txName, rxName string) {
	const interval = 5 * time.Second

	tick := time.NewTicker(interval)
	defer tick.Stop()

	var lastTx, lastRx uint64
	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		curTx := s.tx.Load()
		curRx := s.rx.Load()
		movedTx, movedRx := curTx-lastTx, curRx-lastRx
		lastTx, lastRx = curTx, curRx
		if movedTx == 0 && movedRx == 0 {
			// Nothing was counted during this interval; stay quiet.
			continue
		}
		log.Printf(
			"%s throughput: %s=%s %s=%s total_%s=%s total_%s=%s",
			label,
			txName, formatBitsPerSecond(movedTx, interval),
			rxName, formatBitsPerSecond(movedRx, interval),
			txName, formatByteCount(curTx),
			rxName, formatByteCount(curRx),
		)
	}
}
// formatBitsPerSecond renders the rate of bytes transferred during interval
// as a human-readable bit rate ("bit/s", "kbit/s" or "Mbit/s", decimal units).
// A non-positive interval is treated as one second.
func formatBitsPerSecond(bytes uint64, interval time.Duration) string {
	if interval <= 0 {
		interval = time.Second
	}
	// Convert to float64 before multiplying: bytes*8 evaluated in uint64
	// wraps around for counts of 2^61 and above and would report a bogus rate.
	bps := float64(bytes) * 8 / interval.Seconds()
	switch {
	case bps >= 1_000_000:
		return fmt.Sprintf("%.2f Mbit/s", bps/1_000_000)
	case bps >= 1_000:
		return fmt.Sprintf("%.1f kbit/s", bps/1_000)
	default:
		return fmt.Sprintf("%.0f bit/s", bps)
	}
}
// formatByteCount renders a byte total using binary units: MiB with two
// decimals, KiB with one decimal, or a plain byte count below 1024.
func formatByteCount(bytes uint64) string {
	const (
		kib = 1024
		mib = 1024 * kib
	)
	switch {
	case bytes >= mib:
		return fmt.Sprintf("%.2f MiB", float64(bytes)/mib)
	case bytes >= kib:
		return fmt.Sprintf("%.1f KiB", float64(bytes)/kib)
	default:
		return fmt.Sprintf("%d B", bytes)
	}
}
// countingConn decorates a net.Conn so that every Read and Write also
// reports the transferred byte count to stats.
type countingConn struct {
	net.Conn
	stats *throughputStats
}

// Read delegates to the wrapped connection and records the bytes read as rx.
func (c *countingConn) Read(p []byte) (int, error) {
	got, readErr := c.Conn.Read(p)
	c.stats.addRx(got)
	return got, readErr
}

// Write delegates to the wrapped connection and records the bytes written as tx.
func (c *countingConn) Write(p []byte) (int, error) {
	sent, writeErr := c.Conn.Write(p)
	c.stats.addTx(sent)
	return sent, writeErr
}
// Wire-format and timing constants of the bonding protocol.
const (
// bondVersion is the only hello version accepted by parseBondHelloHeader.
bondVersion = 1
// bondMagic is the 4-byte prefix that identifies a bond hello.
bondMagic = "VLB1"
// bondFrameData marks a frame that carries payload bytes.
bondFrameData byte = 1
// bondFrameFIN marks end of stream; its seq is the number of data frames
// the sender produced.
bondFrameFIN byte = 2
// bondMaxChunk is the size of the backend read buffer and therefore the
// largest payload placed into a single data frame.
bondMaxChunk = 16 * 1024
// bondLaneAttachTimeout bounds how long a new bonded connection waits for
// the advertised number of lanes before it starts with what it has.
bondLaneAttachTimeout = 300 * time.Millisecond
)
// bondHello is the decoded lane handshake header.
type bondHello struct {
// connID selects the bonded connection the lane belongs to.
connID uint64
// laneIndex is the sender-assigned lane number; it appears in log output.
laneIndex uint16
// laneCount is the number of lanes the sender announced.
laneCount uint16
}
// bondFrame is one decoded bonding frame.
type bondFrame struct {
// typ is the frame type byte (bondFrameData or bondFrameFIN).
typ byte
// seq orders data frames for reassembly; for a FIN frame it is the
// sequence number at which the stream ends.
seq uint64
// data is the payload; it is nil when the frame carried no bytes.
data []byte
}
// readBondHello reads a complete 17-byte hello header from r and decodes it.
func readBondHello(r io.Reader) (bondHello, error) {
	raw := make([]byte, 17)
	if _, err := io.ReadFull(r, raw); err != nil {
		return bondHello{}, err
	}
	return parseBondHelloHeader(raw)
}
// readBondHelloAfterMagic decodes a hello whose first four bytes were already
// consumed by the caller and are supplied in magic; the remaining 13 bytes
// are read from r.
func readBondHelloAfterMagic(r io.Reader, magic [4]byte) (bondHello, error) {
	raw := make([]byte, 17)
	rest := raw[copy(raw, magic[:]):]
	if _, err := io.ReadFull(r, rest); err != nil {
		return bondHello{}, err
	}
	return parseBondHelloHeader(raw)
}
// parseBondHelloHeader validates and decodes a hello header. hdr must be
// exactly 17 bytes, start with bondMagic and carry bondVersion in byte 4;
// the connection ID, lane index and lane count follow in big-endian order.
func parseBondHelloHeader(hdr []byte) (bondHello, error) {
	switch {
	case len(hdr) != 17:
		return bondHello{}, fmt.Errorf("bad bond hello size: %d", len(hdr))
	case string(hdr[:4]) != bondMagic:
		return bondHello{}, fmt.Errorf("bad bond magic")
	case hdr[4] != bondVersion:
		return bondHello{}, fmt.Errorf("unsupported bond version: %d", hdr[4])
	}

	order := binary.BigEndian
	hello := bondHello{
		connID:    order.Uint64(hdr[5:13]),
		laneIndex: order.Uint16(hdr[13:15]),
		laneCount: order.Uint16(hdr[15:17]),
	}
	return hello, nil
}
// writeBondFrame writes one frame to w: a 13-byte header (type, big-endian
// 8-byte seq, big-endian 4-byte payload length) followed by the payload,
// if any.
//
// Payloads larger than 4 MiB are rejected before anything is written. The
// reader refuses frames above that size, and the 32-bit length field would
// otherwise be silently truncated for very large slices, desynchronising
// the stream.
func writeBondFrame(w io.Writer, typ byte, seq uint64, data []byte) error {
	// Keep in sync with the limit enforced by readBondFrame.
	const maxPayload = 4 * 1024 * 1024
	if len(data) > maxPayload {
		return fmt.Errorf("bond frame too large: %d", len(data))
	}

	var hdr [13]byte
	hdr[0] = typ
	binary.BigEndian.PutUint64(hdr[1:9], seq)
	binary.BigEndian.PutUint32(hdr[9:13], uint32(len(data)))
	if _, err := w.Write(hdr[:]); err != nil {
		return err
	}
	if len(data) == 0 {
		// Header-only frame (for example FIN): nothing more to send.
		return nil
	}
	_, err := w.Write(data)
	return err
}
// readBondFrame reads exactly one frame from r: a 13-byte header followed by
// the number of payload bytes announced in it. A frame announcing more than
// 4 MiB is rejected without reading its payload. On any error the returned
// frame is the zero value.
func readBondFrame(r io.Reader) (bondFrame, error) {
	var header [13]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return bondFrame{}, err
	}

	payloadLen := binary.BigEndian.Uint32(header[9:])
	if payloadLen > 4*1024*1024 {
		return bondFrame{}, fmt.Errorf("bond frame too large: %d", payloadLen)
	}

	frame := bondFrame{
		typ: header[0],
		seq: binary.BigEndian.Uint64(header[1:9]),
	}
	if payloadLen == 0 {
		// Header-only frame; data stays nil.
		return frame, nil
	}

	frame.data = make([]byte, payloadLen)
	if _, err := io.ReadFull(r, frame.data); err != nil {
		return bondFrame{}, err
	}
	return frame, nil
}
// closeWrite half-closes conn for writing when its concrete type offers a
// CloseWrite method; other connection types are left untouched. A failure is
// logged and otherwise ignored.
func closeWrite(conn net.Conn) {
	halfCloser, ok := conn.(interface{ CloseWrite() error })
	if !ok {
		return
	}
	if err := halfCloser.CloseWrite(); err != nil {
		log.Printf("CloseWrite failed: %v", err)
	}
}
// bondServerLane is one smux stream attached to a bonded connection.
type bondServerLane struct {
// index is the lane number taken from the client's hello; it is used in
// log output.
index uint16
// stream carries bond frames in both directions.
stream *smux.Stream
// mu is held around each frame write on stream.
mu sync.Mutex
}
// bondServerConn is the server-side state of one bonded connection: the set
// of attached lanes plus the channels that coordinate its goroutines.
type bondServerConn struct {
// id is the connection ID from the client's hello.
id uint64
// connectAddr is the backend address dialed by run.
connectAddr string
// ctx and cancel bound the lifetime of every goroutine of this connection.
ctx context.Context
cancel context.CancelFunc
// done is closed when run returns.
done chan struct{}
// lanesMu guards lanes and want.
lanesMu sync.RWMutex
lanes []*bondServerLane
// want is the largest lane count announced by any attached lane.
want uint16
// ready receives a non-blocking signal each time a lane is attached.
ready chan struct{}
// recvCh carries frames from the per-lane readers to copyBondToBackend.
recvCh chan bondFrame
// once ensures run is started exactly one time.
once sync.Once
}
// bondRegistry maps connection IDs to their in-progress bonded connections
// so that lanes arriving on different streams can find the same state.
type bondRegistry struct {
// mu guards conns.
mu sync.Mutex
conns map[uint64]*bondServerConn
}
// globalBondRegistry is the process-wide registry used by the bond stream handlers.
var globalBondRegistry = &bondRegistry{conns: make(map[uint64]*bondServerConn)}
// get returns the bonded connection registered under id, creating and
// registering a new one when none exists. ctx and connectAddr are only used
// when a new connection is created.
func (r *bondRegistry) get(ctx context.Context, id uint64, connectAddr string) *bondServerConn {
r.mu.Lock()
defer r.mu.Unlock()
if c := r.conns[id]; c != nil {
return c
}
// The new connection's context derives from the ctx of whichever caller
// happened to create it.
connCtx, cancel := context.WithCancel(ctx)
c := &bondServerConn{
id: id,
connectAddr: connectAddr,
ctx: connCtx,
cancel: cancel,
done: make(chan struct{}),
ready: make(chan struct{}, 1),
recvCh: make(chan bondFrame, 1024),
}
r.conns[id] = c
// Unregister once done is closed. The identity check avoids deleting a
// different connection that may have been registered under the same id.
go func() {
<-c.done
r.mu.Lock()
if r.conns[id] == c {
delete(r.conns, id)
}
r.mu.Unlock()
}()
return c
}
// addLane attaches lane l to the connection, raises the expected lane count
// to laneCount if that is larger than the current value, starts a reader
// goroutine for the lane and, on the first call only, starts run.
func (c *bondServerConn) addLane(l *bondServerLane, laneCount uint16) {
c.lanesMu.Lock()
if laneCount > c.want {
c.want = laneCount
}
c.lanes = append(c.lanes, l)
count := len(c.lanes)
c.lanesMu.Unlock()
log.Printf("[bond %d] lane %d attached (lanes=%d)", c.id, l.index, count)
// Non-blocking wake-up for waitForInitialLanes; a pending signal is enough,
// so the send is dropped when the one-slot buffer is already full.
select {
case c.ready <- struct{}{}:
default:
}
go c.readLane(l)
c.once.Do(func() {
go c.run()
})
}
// snapshotLanes returns a copy of the currently attached lanes, so callers
// can iterate without holding the lock. The result is never nil.
func (c *bondServerConn) snapshotLanes() []*bondServerLane {
	c.lanesMu.RLock()
	defer c.lanesMu.RUnlock()

	dup := make([]*bondServerLane, 0, len(c.lanes))
	return append(dup, c.lanes...)
}
// removeLane detaches lane l if it is attached and returns how many lanes
// remain afterwards. Removing a lane that is not attached is a no-op.
func (c *bondServerConn) removeLane(l *bondServerLane) int {
	c.lanesMu.Lock()
	defer c.lanesMu.Unlock()

	for idx := range c.lanes {
		if c.lanes[idx] != l {
			continue
		}
		c.lanes = append(c.lanes[:idx], c.lanes[idx+1:]...)
		break
	}
	return len(c.lanes)
}
// waitForInitialLanes blocks until the number of attached lanes reaches the
// announced lane count. It also returns when no count was announced, when
// the connection context is done, or when bondLaneAttachTimeout elapses.
func (c *bondServerConn) waitForInitialLanes() {
timer := time.NewTimer(bondLaneAttachTimeout)
defer timer.Stop()
for {
c.lanesMu.RLock()
count := len(c.lanes)
want := int(c.want)
c.lanesMu.RUnlock()
if want <= 0 || count >= want {
return
}
select {
case <-c.ctx.Done():
return
case <-c.ready:
// A lane was attached; loop to re-read the counts.
case <-timer.C:
log.Printf("[bond %d] starting with %d/%d lanes after attach timeout", c.id, count, want)
return
}
}
}
// readLane reads frames from lane l and forwards them to recvCh until a read
// fails or the connection context is done. On a read error the lane is
// detached; if it was the last lane and the connection is still running, the
// whole connection is cancelled.
func (c *bondServerConn) readLane(l *bondServerLane) {
for {
f, err := readBondFrame(l.stream)
if err != nil {
left := c.removeLane(l)
select {
case <-c.ctx.Done():
// Already shutting down: the read error is expected, stay quiet.
default:
if err != io.EOF {
log.Printf("[bond %d] lane %d read error: %v (lanes=%d)", c.id, l.index, err, left)
}
if left == 0 {
c.cancel()
}
}
return
}
select {
case c.recvCh <- f:
case <-c.ctx.Done():
return
}
}
}
// run drives one bonded connection: it waits for the initial lanes, dials
// the backend over TCP with a 10 second timeout and then copies in both
// directions until both copy goroutines have returned. done is closed and
// the connection context cancelled when run returns.
func (c *bondServerConn) run() {
defer close(c.done)
defer c.cancel()
c.waitForInitialLanes()
backendConn, err := net.DialTimeout("tcp", c.connectAddr, 10*time.Second)
if err != nil {
log.Printf("[bond %d] backend dial error: %s", c.id, err)
return
}
defer func() {
if err := backendConn.Close(); err != nil {
log.Printf("[bond %d] failed to close backend connection: %v", c.id, err)
}
}()
// On cancellation, expire the deadlines of the backend connection and of
// every attached lane so blocked reads and writes return.
context.AfterFunc(c.ctx, func() {
now := time.Now()
if err := backendConn.SetDeadline(now); err != nil {
log.Printf("[bond %d] backend deadline error: %v", c.id, err)
}
for _, lane := range c.snapshotLanes() {
if err := lane.stream.SetDeadline(now); err != nil {
log.Printf("[bond %d] lane %d deadline error: %v", c.id, lane.index, err)
}
}
})
log.Printf("[bond %d] backend connected", c.id)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
c.copyBondToBackend(backendConn)
}()
go func() {
defer wg.Done()
// The backend-to-client direction finishing cancels the whole connection.
defer c.cancel()
c.copyBackendToBond(backendConn)
}()
wg.Wait()
}
// copyBondToBackend reassembles frames arriving on recvCh into their
// original order and writes the payloads to backendConn.
//
// Data frames are buffered by sequence number until the next expected one is
// available. A FIN frame records where the stream ends (the smallest value
// seen wins); once everything up to that point has been written, the write
// side of backendConn is half-closed. The function also returns when the
// connection context is done, an unknown frame type arrives, or a backend
// write fails.
func (c *bondServerConn) copyBondToBackend(backendConn net.Conn) {
	pending := make(map[uint64][]byte)
	var expect uint64
	var finSeq *uint64
	for {
		// >= rather than ==: a FIN below the already delivered position must
		// still end the stream instead of leaving the loop waiting forever.
		if finSeq != nil && expect >= *finSeq {
			closeWrite(backendConn)
			log.Printf("[bond %d] upload to backend finished chunks=%d", c.id, expect)
			return
		}
		select {
		case <-c.ctx.Done():
			return
		case f := <-c.recvCh:
			switch f.typ {
			case bondFrameData:
				// A frame older than the delivery position can never be
				// consumed; storing it would only leak memory.
				if f.seq >= expect {
					pending[f.seq] = f.data
				}
			case bondFrameFIN:
				v := f.seq
				if finSeq == nil || v < *finSeq {
					finSeq = &v
				}
			default:
				log.Printf("[bond %d] unknown frame type %d", c.id, f.typ)
				return
			}
			// Flush every frame that is now contiguous with what was delivered.
			for {
				data, ok := pending[expect]
				if !ok {
					break
				}
				delete(pending, expect)
				if len(data) > 0 {
					if _, err := backendConn.Write(data); err != nil {
						log.Printf("[bond %d] backend write error: %v", c.id, err)
						return
					}
				}
				expect++
			}
		}
	}
}
// copyBackendToBond reads backendConn in chunks of at most bondMaxChunk
// bytes and sends each chunk as a sequenced data frame over the next lane.
// When the backend read ends (EOF or error) it writes a FIN frame, carrying
// the number of data frames sent, to every attached lane. It returns early
// if a frame cannot be written to any lane or once the connection context is
// done.
func (c *bondServerConn) copyBackendToBond(backendConn net.Conn) {
	buf := make([]byte, bondMaxChunk)
	var seq uint64
	var laneIdx uint64
	for {
		n, err := backendConn.Read(buf)
		if n > 0 {
			// buf[:n] is passed directly: the lane write completes before buf
			// is reused, and io.Writer implementations must not retain the
			// slice, so a per-chunk copy is unnecessary.
			if writeErr := c.writeToNextLane(bondFrameData, seq, buf[:n], &laneIdx); writeErr != nil {
				log.Printf("[bond %d] lane write data error: %v", c.id, writeErr)
				return
			}
			seq++
		}
		if err != nil {
			// Announce end of stream on every lane that is still attached.
			for _, lane := range c.snapshotLanes() {
				lane.mu.Lock()
				writeErr := writeBondFrame(lane.stream, bondFrameFIN, seq, nil)
				lane.mu.Unlock()
				if writeErr != nil && c.ctx.Err() == nil {
					log.Printf("[bond %d] lane %d write FIN error: %v", c.id, lane.index, writeErr)
				}
			}
			log.Printf("[bond %d] download from backend finished chunks=%d", c.id, seq)
			return
		}
		select {
		case <-c.ctx.Done():
			return
		default:
		}
	}
}
// writeToNextLane sends one frame over the attached lanes in rotation,
// starting at *laneIdx and advancing it for every lane tried. A lane whose
// write fails is detached. It returns nil once some lane accepted the frame,
// the write error when the last lane was just detached, or the context error
// once the connection context is done.
func (c *bondServerConn) writeToNextLane(typ byte, seq uint64, data []byte, laneIdx *uint64) error {
for {
// Work on a copy so lanes can be detached while iterating.
lanes := c.snapshotLanes()
for attempts := 0; attempts < len(lanes); attempts++ {
lane := lanes[*laneIdx%uint64(len(lanes))]
(*laneIdx)++
lane.mu.Lock()
err := writeBondFrame(lane.stream, typ, seq, data)
lane.mu.Unlock()
if err == nil {
return nil
}
left := c.removeLane(lane)
log.Printf("[bond %d] lane %d write error: %v (lanes=%d)", c.id, lane.index, err, left)
if left == 0 {
return err
}
}
// No lane accepted the frame in this pass; pause briefly and retry with a
// fresh snapshot unless the connection has been cancelled.
select {
case <-c.ctx.Done():
return c.ctx.Err()
case <-time.After(10 * time.Millisecond):
}
}
}
// handleBondServerStream serves a bond lane whose hello header, including
// the magic, has not been read from stream yet.
func handleBondServerStream(ctx context.Context, stream *smux.Stream, connectAddr string) {
handleBondServerStreamWithHello(ctx, stream, connectAddr, readBondHello)
}
// handleBondServerStreamAfterMagic serves a bond lane whose first four bytes
// were already read from stream by the caller and are passed in magic.
func handleBondServerStreamAfterMagic(ctx context.Context, stream *smux.Stream, connectAddr string, magic [4]byte) {
handleBondServerStreamWithHello(ctx, stream, connectAddr, func(r io.Reader) (bondHello, error) {
return readBondHelloAfterMagic(r, magic)
})
}
// handleBondServerStreamWithHello decodes the lane hello with readHello,
// attaches stream as a lane of the bonded connection named in the hello and
// then blocks until ctx is done or that connection has finished. The stream
// is closed on return.
func handleBondServerStreamWithHello(ctx context.Context, stream *smux.Stream, connectAddr string, readHello func(io.Reader) (bondHello, error)) {
defer func() {
if err := stream.Close(); err != nil && err != smux.ErrGoAway {
log.Printf("failed to close bond smux stream: %v", err)
}
}()
hello, err := readHello(stream)
if err != nil {
log.Printf("bond hello error: %v", err)
return
}
conn := globalBondRegistry.get(ctx, hello.connID, connectAddr)
conn.addLane(&bondServerLane{
index: hello.laneIndex,
stream: stream,
}, hello.laneCount)
// Keep the stream open for as long as the bonded connection uses it.
select {
case <-ctx.Done():
case <-conn.done:
}
}
// prefixedConn is a net.Conn whose reads first yield the bytes held in
// prefix and only afterwards read from the wrapped connection.
type prefixedConn struct {
	net.Conn
	prefix []byte
}

// Read returns the remaining prefix bytes, as many as fit into p, before it
// ever touches the wrapped connection. A single call never mixes prefix
// bytes with bytes from the connection.
func (c *prefixedConn) Read(p []byte) (int, error) {
	if len(c.prefix) == 0 {
		return c.Conn.Read(p)
	}
	copied := copy(p, c.prefix)
	c.prefix = c.prefix[copied:]
	return copied, nil
}
// handleUDPConnection forwards DTLS packets to a UDP backend (WireGuard).
func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) {
serverConn, err := net.Dial("udp", connectAddr)
@ -141,6 +692,14 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string)
var wg sync.WaitGroup
wg.Add(2)
ctx2, cancel2 := context.WithCancel(ctx)
stats := &throughputStats{}
go stats.logEvery(
ctx2,
fmt.Sprintf("[DTLS %s]", conn.RemoteAddr()),
"dtls-to-backend",
"backend-to-dtls",
)
context.AfterFunc(ctx2, func() {
if err := conn.SetDeadline(time.Now()); err != nil {
log.Printf("failed to set incoming deadline: %s", err)
@ -173,7 +732,8 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string)
log.Printf("Failed: %s", err1)
return
}
_, err1 = serverConn.Write(buf[:n])
written, err1 := serverConn.Write(buf[:n])
stats.addTx(written)
if err1 != nil {
log.Printf("Failed: %s", err1)
return
@ -204,7 +764,8 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string)
log.Printf("Failed: %s", err1)
return
}
_, err1 = conn.Write(buf[:n])
written, err1 := conn.Write(buf[:n])
stats.addRx(written)
if err1 != nil {
log.Printf("Failed: %s", err1)
return
@ -216,9 +777,19 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string)
// handleVLESSConnection creates a KCP+smux session over DTLS and forwards
// each smux stream as a TCP connection to the backend (Xray/VLESS).
func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr string) {
func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr string, bond bool) {
// 1. Create KCP session over DTLS
kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, true)
statsCtx, statsCancel := context.WithCancel(ctx)
defer statsCancel()
stats := &throughputStats{}
go stats.logEvery(
statsCtx,
fmt.Sprintf("[VLESS %s]", dtlsConn.RemoteAddr()),
"to-client",
"from-client",
)
kcpSess, err := tcputil.NewKCPOverDTLS(&countingConn{Conn: dtlsConn, stats: stats}, true)
if err != nil {
log.Printf("KCP session error: %s", err)
return
@ -260,6 +831,23 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s
go func(s *smux.Stream) {
defer wg.Done()
var prefix [4]byte
if _, err := io.ReadFull(s, prefix[:]); err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {
log.Printf("smux stream prefix read error: %v", err)
}
_ = s.Close()
return
}
if string(prefix[:]) == bondMagic {
log.Printf("auto-detected bond smux stream")
handleBondServerStreamAfterMagic(ctx, s, connectAddr, prefix)
return
}
if bond {
log.Printf("non-bond smux stream accepted while -vless-bond is enabled")
}
defer func() {
if err := s.Close(); err != nil && err != smux.ErrGoAway {
log.Printf("failed to close smux stream: %v", err)
@ -279,7 +867,7 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s
}()
// Bidirectional copy
pipeConn(ctx, s, backendConn)
pipeConn(ctx, &prefixedConn{Conn: s, prefix: prefix[:]}, backendConn)
}(stream)
}
wg.Wait()

111
tcputil/tcputil.go

@ -2,12 +2,111 @@ package tcputil
import (
"net"
"os"
"strconv"
"strings"
"time"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
)
// kcpProfile bundles the KCP tuning values that NewKCPOverDTLS applies to a
// session. Fields are grouped by the session setter that consumes them.
type kcpProfile struct {
	// Arguments to SetNoDelay, in call order (interval is in milliseconds).
	nodelay, interval, resend, nc int

	// Send and receive window sizes passed to SetWindowSize.
	sndWnd, rcvWnd int

	// Value passed to SetMtu.
	mtu int

	// Value passed to SetACKNoDelay.
	ackNoDelay bool
}
// selectedKCPProfile resolves the KCP tuning to use from the environment.
//
// VK_TURN_KCP_PROFILE (case-insensitive, surrounding whitespace ignored)
// picks a preset: "legacy"/"fast", "cc"/"balanced" or "slow"/"conservative".
// An unset or unrecognised value yields the built-in default preset.
// Each field can then be overridden individually through the matching
// VK_TURN_KCP_* variable; unset or unparsable overrides leave the preset
// value in place.
func selectedKCPProfile() kcpProfile {
	// Start from the preset used when no known profile name is given.
	p := kcpProfile{
		nodelay:    1,
		interval:   20,
		resend:     2,
		nc:         1,
		sndWnd:     512,
		rcvWnd:     512,
		mtu:        1200,
		ackNoDelay: true,
	}

	switch strings.ToLower(strings.TrimSpace(os.Getenv("VK_TURN_KCP_PROFILE"))) {
	case "legacy", "fast":
		p = kcpProfile{
			nodelay:    1,
			interval:   10,
			resend:     2,
			nc:         1,
			sndWnd:     4096,
			rcvWnd:     4096,
			mtu:        1280,
			ackNoDelay: true,
		}
	case "cc", "balanced":
		p = kcpProfile{
			nodelay:    1,
			interval:   20,
			resend:     2,
			nc:         0,
			sndWnd:     512,
			rcvWnd:     512,
			mtu:        1200,
			ackNoDelay: true,
		}
	case "slow", "conservative":
		p = kcpProfile{
			nodelay:    0,
			interval:   40,
			resend:     2,
			nc:         0,
			sndWnd:     256,
			rcvWnd:     256,
			mtu:        1150,
			ackNoDelay: false,
		}
	}

	// Per-field overrides; each falls back to the preset value above.
	p.nodelay = envInt("VK_TURN_KCP_NODELAY", p.nodelay)
	p.interval = envInt("VK_TURN_KCP_INTERVAL", p.interval)
	p.resend = envInt("VK_TURN_KCP_RESEND", p.resend)
	p.nc = envInt("VK_TURN_KCP_NC", p.nc)
	p.sndWnd = envInt("VK_TURN_KCP_SNDWND", p.sndWnd)
	p.rcvWnd = envInt("VK_TURN_KCP_RCVWND", p.rcvWnd)
	p.mtu = envInt("VK_TURN_KCP_MTU", p.mtu)
	p.ackNoDelay = envBool("VK_TURN_KCP_ACK_NODELAY", p.ackNoDelay)
	return p
}
// envInt reads the environment variable name as a base-10 integer.
// Surrounding whitespace is ignored. fallback is returned when the variable
// is unset, blank, or does not parse as an int.
func envInt(name string, fallback int) int {
	text := strings.TrimSpace(os.Getenv(name))
	if text != "" {
		if parsed, err := strconv.Atoi(text); err == nil {
			return parsed
		}
	}
	return fallback
}
// envBool reads the environment variable name as a boolean flag.
// Matching is case-insensitive and ignores surrounding whitespace:
// "1", "true", "yes", "on" give true; "0", "false", "no", "off" give false.
// Any other value (including unset) yields fallback.
func envBool(name string, fallback bool) bool {
	result := fallback
	switch v := strings.ToLower(strings.TrimSpace(os.Getenv(name))); v {
	case "0", "false", "no", "off":
		result = false
	case "1", "true", "yes", "on":
		result = true
	}
	return result
}
// DtlsPacketConn wraps a net.Conn (DTLS) as a net.PacketConn for KCP.
// Each DTLS Read/Write preserves message boundaries (datagram semantics).
type DtlsPacketConn struct {
@ -81,13 +180,11 @@ func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (*kcp.UDPSession, error) {
}
}
// Tune KCP for TURN tunnel:
// - NoDelay mode for lower latency
// - Window sizes suitable for ~5Mbit/s
sess.SetNoDelay(1, 10, 2, 1) // nodelay, interval(ms), resend, nc
sess.SetWindowSize(4096, 4096)
sess.SetMtu(1280) // conservative MTU to fit inside DTLS+TURN
sess.SetACKNoDelay(true)
profile := selectedKCPProfile()
sess.SetNoDelay(profile.nodelay, profile.interval, profile.resend, profile.nc)
sess.SetWindowSize(profile.sndWnd, profile.rcvWnd)
sess.SetMtu(profile.mtu)
sess.SetACKNoDelay(profile.ackNoDelay)
return sess, nil
}

Loading…
Cancel
Save