diff --git a/server/main.go b/server/main.go index 36cd7cf..729262b 100644 --- a/server/main.go +++ b/server/main.go @@ -48,7 +48,7 @@ func main() { // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() if genErr != nil { - panic(err) + panic(genErr) } // Prepare the configuration of the DTLS connection @@ -98,19 +98,18 @@ func main() { // Perform the handshake with a 30-second timeout ctx1, cancel1 := context.WithTimeout(ctx, 30*time.Second) + defer cancel1() + dtlsConn, ok := conn.(*dtls.Conn) if !ok { - log.Println("Type error") - cancel1() + log.Println("Type error: expected *dtls.Conn") return } log.Println("Start handshake") if err := dtlsConn.HandshakeContext(ctx1); err != nil { - log.Println(err) - cancel1() + log.Printf("Handshake failed: %v", err) return } - cancel1() log.Println("Handshake done") if *vlessMode { @@ -222,7 +221,11 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s log.Printf("KCP session error: %s", err) return } - defer kcpSess.Close() + defer func() { + if err := kcpSess.Close(); err != nil { + log.Printf("failed to close KCP session: %v", err) + } + }() log.Printf("KCP session established (server)") // 2. Create smux server session over KCP @@ -231,7 +234,11 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s log.Printf("smux server error: %s", err) return } - defer smuxSess.Close() + defer func() { + if err := smuxSess.Close(); err != nil { + log.Printf("failed to close smux session: %v", err) + } + }() log.Printf("smux session established (server)") // 3. Accept smux streams and forward to backend via TCP @@ -250,7 +257,12 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s wg.Add(1) go func(s *smux.Stream) { defer wg.Done() - defer s.Close() + + defer func() { + if err := s.Close(); err != nil && err != smux.ErrGoAway { + log.Printf("failed to close smux stream: %v", err) + } + }() // Connect to backend (Xray/VLESS) backendConn, err := net.DialTimeout("tcp", connectAddr, 10*time.Second) @@ -258,7 +270,11 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s log.Printf("backend dial error: %s", err) return } - defer backendConn.Close() + defer func() { + if err := backendConn.Close(); err != nil { + log.Printf("failed to close backend connection: %v", err) + } + }() // Bidirectional copy pipeConn(ctx, s, backendConn) @@ -270,6 +286,8 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s // pipeConn copies data bidirectionally between two connections. func pipeConn(ctx context.Context, c1, c2 net.Conn) { ctx2, cancel := context.WithCancel(ctx) + defer cancel() + context.AfterFunc(ctx2, func() { if err := c1.SetDeadline(time.Now()); err != nil { log.Printf("pipeConn: failed to set deadline c1: %v", err) @@ -281,25 +299,24 @@ func pipeConn(ctx context.Context, c1, c2 net.Conn) { var wg sync.WaitGroup wg.Add(2) + go func() { defer wg.Done() - defer cancel() if _, err := io.Copy(c1, c2); err != nil { log.Printf("pipeConn: c1<-c2 copy error: %v", err) } }() + go func() { defer wg.Done() - defer cancel() if _, err := io.Copy(c2, c1); err != nil { log.Printf("pipeConn: c2<-c1 copy error: %v", err) } }() + wg.Wait() - if err := c1.SetDeadline(time.Time{}); err != nil { - log.Printf("pipeConn: failed to reset deadline c1: %v", err) - } - if err := c2.SetDeadline(time.Time{}); err != nil { - log.Printf("pipeConn: failed to reset deadline c2: %v", err) - } + + // Reset deadlines + _ = c1.SetDeadline(time.Time{}) + _ = c2.SetDeadline(time.Time{}) }