@ -11,7 +11,6 @@ import (
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
"time"
@ -19,11 +18,15 @@ import (
"github.com/pion/dtls/v3/pkg/crypto/selfsign"
)
type streamEntry struct {
id byte
conn net . Conn
}
type UserSession struct {
ID string
Conns [ ] net . Conn
Conns [ ] streamEntry
BackendConn net . Conn
LastUsed uint32
Lock sync . RWMutex
Ctx context . Context
Cancel context . CancelFunc
@ -51,12 +54,12 @@ func (s *SessionManager) GetOrCreate(ctx context.Context, id string, connectAddr
sessionCtx , cancel := context . WithCancel ( ctx )
session := & UserSession {
ID : id ,
Conns : make ( [ ] streamEntry , 0 ) ,
BackendConn : backendConn ,
Manager : s ,
Ctx : sessionCtx ,
Cancel : cancel ,
}
s . Sessions [ id ] = session
go session . backendReaderLoop ( )
@ -66,6 +69,7 @@ func (s *SessionManager) GetOrCreate(ctx context.Context, id string, connectAddr
func ( s * UserSession ) backendReaderLoop ( ) {
defer s . Cleanup ( )
buf := make ( [ ] byte , 1600 )
var lastUsed uint32 = 0
for {
select {
case <- s . Ctx . Done ( ) :
@ -81,43 +85,54 @@ func (s *UserSession) backendReaderLoop() {
}
s . Lock . RLock ( )
if len ( s . Conns ) == 0 {
nConns := uint32 ( len ( s . Conns ) )
if nConns == 0 {
s . Lock . RUnlock ( )
continue
}
// Round-robin selection of DTLS connection
idx := atomic . AddUint32 ( & s . LastUsed , 1 ) % uint32 ( len ( s . Conns ) )
conn := s . Conns [ idx ]
// Fast Round-robin selection using local variable
lastUsed = ( lastUsed + 1 ) % nConns
conn := s . Conns [ lastUsed ] . conn
s . Lock . RUnlock ( )
conn . SetWriteDeadline ( time . Now ( ) . Add ( time . Second * 10 ) )
_ , err = conn . Write ( buf [ : n ] )
if err != nil {
log . Printf ( "Session %s DTLS write error: %v" , s . ID , err )
// Connection will be removed by its own reader loop
conn . Close ( )
}
}
}
}
func ( s * UserSession ) AddConn ( conn net . Conn ) {
func ( s * UserSession ) AddConn ( id byte , conn net . Conn ) {
s . Lock . Lock ( )
defer s . Lock . Unlock ( )
s . Conns = append ( s . Conns , conn )
// Evict existing connection with same ID
for i , entry := range s . Conns {
if entry . id == id {
//log.Printf("Session %s: Evicting old stream %d", s.ID, id)
entry . conn . Close ( )
s . Conns [ i ] . conn = conn
return
}
}
s . Conns = append ( s . Conns , streamEntry { id : id , conn : conn } )
}
func ( s * UserSession ) RemoveConn ( conn net . Conn ) {
func ( s * UserSession ) RemoveConn ( id byte , conn net . Conn ) {
s . Lock . Lock ( )
defer s . Lock . Unlock ( )
for i , c := range s . Conns {
if c == conn {
for i , entry := range s . Conns {
if entry . id == id && entry . conn == conn {
s . Conns = append ( s . Conns [ : i ] , s . Conns [ i + 1 : ] ... )
break
}
}
// If all connections are gone, we might want to start a timer to cleanup the session
// but for now we'll keep it alive until backendReaderLoop fails or context is cancelled.
}
func ( s * UserSession ) Cleanup ( ) {
s . Cancel ( )
s . BackendConn . Close ( )
@ -127,13 +142,12 @@ func (s *UserSession) Cleanup() {
s . Manager . Lock . Unlock ( )
s . Lock . Lock ( )
for _ , c := range s . Conns {
c . Close ( )
for _ , entry := range s . Conns {
entry . conn . Close ( )
}
s . Conns = nil
s . Lock . Unlock ( )
}
func main ( ) {
listen := flag . String ( "listen" , "0.0.0.0:56000" , "listen on ip:port" )
connect := flag . String ( "connect" , "" , "connect to ip:port" )
@ -213,15 +227,16 @@ func main() {
return
}
// Phase 1: Read Session ID (16 bytes)
idBuf := make ( [ ] byte , 16 )
// Phase 1: Read Session ID + Stream ID (17 bytes)
idBuf := make ( [ ] byte , 17 )
conn . SetReadDeadline ( time . Now ( ) . Add ( time . Second * 5 ) )
_ , err := io . ReadFull ( conn , idBuf )
if err != nil {
log . Println ( "Failed to read session ID:" , err )
return
}
sessionID := fmt . Sprintf ( "%x" , idBuf )
sessionID := fmt . Sprintf ( "%x" , idBuf [ : 16 ] )
streamID := idBuf [ 16 ]
session , err := manager . GetOrCreate ( ctx , sessionID , * connect )
if err != nil {
@ -229,15 +244,15 @@ func main() {
return
}
session . AddConn ( conn )
defer session . RemoveConn ( conn )
session . AddConn ( streamID , conn )
defer session . RemoveConn ( streamID , conn )
log . Printf ( "New stream for session %s from %s" , sessionID , conn . RemoteAddr ( ) )
log . Printf ( "New stream %d for session %s from %s" , streamID , sessionID , conn . RemoteAddr ( ) )
// Upstream Loop: DTLS -> Backend
buf := make ( [ ] byte , 1600 )
for {
conn . SetReadDeadline ( time . Now ( ) . Add ( time . Minute * 10 ) )
conn . SetReadDeadline ( time . Now ( ) . Add ( time . Minute * 5 ) )
n , err := conn . Read ( buf )
if err != nil {
log . Printf ( "Stream %s closed: %v" , sessionID , err )