diff --git a/server/main.go b/server/main.go index e4f2e0e..fcb68dc 100644 --- a/server/main.go +++ b/server/main.go @@ -110,7 +110,11 @@ func main() { } listener, err = dtls.NewListenerWithOptions(wrapListener, dtlsOpts...) } else { - listener, err = dtls.ListenWithOptions("udp", addr, dtlsOpts...) + udpListener, lerr := listenUDPForDTLS(addr) + if lerr != nil { + panic(lerr) + } + listener, err = dtls.NewListenerWithOptions(udpListener, dtlsOpts...) } if err != nil { panic(err) diff --git a/server/pktinfo_linux.go b/server/pktinfo_linux.go new file mode 100644 index 0000000..0b53a70 --- /dev/null +++ b/server/pktinfo_linux.go @@ -0,0 +1,123 @@ +//go:build linux + +package main + +import ( + "fmt" + "net" + "sync" + "time" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +type packetInfoUDPConn struct { + conn *net.UDPConn + ipv4 *ipv4.PacketConn + ipv6 *ipv6.PacketConn + v6 bool + + mu sync.RWMutex + localIPs map[string]net.IP +} + +func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + conn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, err + } + + pc := &packetInfoUDPConn{ + conn: conn, + ipv4: ipv4.NewPacketConn(conn), + ipv6: ipv6.NewPacketConn(conn), + v6: laddr != nil && laddr.IP != nil && laddr.IP.To4() == nil, + localIPs: make(map[string]net.IP), + } + if pc.v6 { + if err = pc.ipv6.SetControlMessage(ipv6.FlagDst, true); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("enable IPv6 packet info: %w", err) + } + } else if err = pc.ipv4.SetControlMessage(ipv4.FlagDst, true); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("enable IPv4 packet info: %w", err) + } + + return pc, nil +} + +func (c *packetInfoUDPConn) ReadFrom(p []byte) (int, net.Addr, error) { + if c.v6 { + n, cm, addr, err := c.ipv6.ReadFrom(p) + if err != nil { + return n, addr, err + } + if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil { + c.rememberLocalIP(udpAddr.String(), cm.Dst) + } + return n, addr, nil + } + + n, cm, addr, err := c.ipv4.ReadFrom(p) + if err != nil { + return n, addr, err + } + if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil { + c.rememberLocalIP(udpAddr.String(), cm.Dst) + } + return n, addr, nil +} + +func (c *packetInfoUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, fmt.Errorf("packet info write: expected *net.UDPAddr, got %T", addr) + } + + localIP := c.localIPFor(udpAddr.String()) + if localIP == nil { + return c.conn.WriteTo(p, addr) + } + if localIP.To4() != nil { + return c.ipv4.WriteTo(p, &ipv4.ControlMessage{Src: localIP}, addr) + } + return c.ipv6.WriteTo(p, &ipv6.ControlMessage{Src: localIP}, addr) +} + +func (c *packetInfoUDPConn) Close() error { + return c.conn.Close() +} + +func (c *packetInfoUDPConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *packetInfoUDPConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *packetInfoUDPConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *packetInfoUDPConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *packetInfoUDPConn) rememberLocalIP(remote string, ip net.IP) { + c.mu.Lock() + defer c.mu.Unlock() + c.localIPs[remote] = append(net.IP(nil), ip...) +} + +func (c *packetInfoUDPConn) localIPFor(remote string) net.IP { + c.mu.RLock() + defer c.mu.RUnlock() + ip := c.localIPs[remote] + if ip == nil { + return nil + } + return append(net.IP(nil), ip...) +} diff --git a/server/pktinfo_other.go b/server/pktinfo_other.go new file mode 100644 index 0000000..ac10c4d --- /dev/null +++ b/server/pktinfo_other.go @@ -0,0 +1,9 @@ +//go:build !linux + +package main + +import "net" + +func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) { + return net.ListenUDP(network, laddr) +} diff --git a/server/udp_listener.go b/server/udp_listener.go new file mode 100644 index 0000000..1085a26 --- /dev/null +++ b/server/udp_listener.go @@ -0,0 +1,238 @@ +package main + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + dtlsnet "github.com/pion/dtls/v3/pkg/net" + "github.com/pion/dtls/v3/pkg/protocol" + "github.com/pion/dtls/v3/pkg/protocol/recordlayer" + "github.com/pion/transport/v4/deadline" + "github.com/pion/transport/v4/packetio" +) + +const udpReceiveMTU = 8192 + +var errUDPPacketListenerClosed = errors.New("udp packet listener closed") + +type udpAcceptFilter func([]byte) bool + +type udpPacketListener struct { + pConn net.PacketConn + acceptFilter udpAcceptFilter + + accepting atomic.Bool + acceptCh chan *udpPacketConn + doneCh chan struct{} + doneOnce sync.Once + + connLock sync.Mutex + conns map[string]*udpPacketConn + connWG sync.WaitGroup + + readDoneCh chan struct{} + readWG sync.WaitGroup + errRead atomic.Value + errClose atomic.Value +} + +func listenUDPForDTLS(addr *net.UDPAddr) (dtlsnet.PacketListener, error) { + pConn, err := listenPacketInfoUDP("udp", addr) + if err != nil { + return nil, err + } + return newUDPPacketListener(pConn, isDTLSHandshakePacket), nil +} + +func newUDPPacketListener(pConn net.PacketConn, acceptFilter udpAcceptFilter) dtlsnet.PacketListener { + l := &udpPacketListener{ + pConn: pConn, + acceptFilter: acceptFilter, + acceptCh: make(chan *udpPacketConn, 128), + doneCh: make(chan struct{}), + conns: make(map[string]*udpPacketConn), + readDoneCh: make(chan struct{}), + } + l.accepting.Store(true) + l.connWG.Add(1) + l.readWG.Add(2) + go l.readLoop() + go func() { + l.connWG.Wait() + if err := l.pConn.Close(); err != nil { + l.errClose.Store(err) + } + l.readWG.Done() + }() + return l +} + +func (l *udpPacketListener) Accept() (net.PacketConn, net.Addr, error) { + select { + case c := <-l.acceptCh: + l.connWG.Add(1) + return c, c.rAddr, nil + case <-l.readDoneCh: + if err, ok := l.errRead.Load().(error); ok { + return nil, nil, err + } + return nil, nil, errUDPPacketListenerClosed + case <-l.doneCh: + return nil, nil, errUDPPacketListenerClosed + } +} + +func (l *udpPacketListener) Close() error { + var err error + l.doneOnce.Do(func() { + l.accepting.Store(false) + close(l.doneCh) + + l.connLock.Lock() + for { + select { + case c := <-l.acceptCh: + close(c.doneCh) + delete(l.conns, c.rAddr.String()) + default: + l.connLock.Unlock() + l.connWG.Done() + l.readWG.Wait() + if errClose, ok := l.errClose.Load().(error); ok { + err = errClose + } + return + } + } + }) + return err +} + +func (l *udpPacketListener) Addr() net.Addr { + return l.pConn.LocalAddr() +} + +func (l *udpPacketListener) readLoop() { + defer l.readWG.Done() + defer close(l.readDoneCh) + + buf := make([]byte, udpReceiveMTU) + for { + n, raddr, err := l.pConn.ReadFrom(buf) + if err != nil { + l.errRead.Store(err) + return + } + l.dispatchMsg(raddr, buf[:n]) + } +} + +func (l *udpPacketListener) dispatchMsg(raddr net.Addr, buf []byte) { + conn, ok := l.getConn(raddr, buf) + if ok { + _, _ = conn.buffer.Write(buf) + } +} + +func (l *udpPacketListener) getConn(raddr net.Addr, buf []byte) (*udpPacketConn, bool) { + l.connLock.Lock() + defer l.connLock.Unlock() + + conn, ok := l.conns[raddr.String()] + if !ok { + if !l.accepting.Load() { + return nil, false + } + if l.acceptFilter != nil && !l.acceptFilter(buf) { + return nil, false + } + conn = &udpPacketConn{ + listener: l, + rAddr: raddr, + buffer: packetio.NewBuffer(), + doneCh: make(chan struct{}), + writeDeadline: deadline.New(), + } + select { + case l.acceptCh <- conn: + l.conns[raddr.String()] = conn + default: + return nil, false + } + } + return conn, true +} + +type udpPacketConn struct { + listener *udpPacketListener + rAddr net.Addr + buffer *packetio.Buffer + + doneCh chan struct{} + doneOnce sync.Once + + writeDeadline *deadline.Deadline +} + +func (c *udpPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + n, err := c.buffer.Read(p) + return n, c.rAddr, err +} + +func (c *udpPacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { + select { + case <-c.writeDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + return c.listener.pConn.WriteTo(p, c.rAddr) +} + +func (c *udpPacketConn) Close() error { + var err error + c.doneOnce.Do(func() { + c.listener.connWG.Done() + close(c.doneCh) + c.listener.connLock.Lock() + delete(c.listener.conns, c.rAddr.String()) + c.listener.connLock.Unlock() + if errBuf := c.buffer.Close(); errBuf != nil { + err = errBuf + } + }) + return err +} + +func (c *udpPacketConn) LocalAddr() net.Addr { + return c.listener.pConn.LocalAddr() +} + +func (c *udpPacketConn) SetDeadline(t time.Time) error { + c.writeDeadline.Set(t) + return c.SetReadDeadline(t) +} + +func (c *udpPacketConn) SetReadDeadline(t time.Time) error { + return c.buffer.SetReadDeadline(t) +} + +func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + return nil +} + +func isDTLSHandshakePacket(packet []byte) bool { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) == 0 { + return false + } + h := &recordlayer.Header{} + if err := h.Unmarshal(pkts[0]); err != nil { + return false + } + return h.ContentType == protocol.ContentTypeHandshake +} diff --git a/server/wrap.go b/server/wrap.go index 1a0917b..3c5bc67 100644 --- a/server/wrap.go +++ b/server/wrap.go @@ -14,7 +14,6 @@ import ( "time" dtlsnet "github.com/pion/dtls/v3/pkg/net" - pionudp "github.com/pion/transport/v4/udp" "golang.org/x/crypto/chacha20poly1305" ) @@ -60,12 +59,12 @@ func listenWrapped(addr *net.UDPAddr, key []byte) (dtlsnet.PacketListener, error if err != nil { return nil, err } - inner, err := pionudp.Listen("udp", addr) + innerConn, err := listenPacketInfoUDP("udp", addr) if err != nil { return nil, fmt.Errorf("wrap: udp listen: %w", err) } return &wrapPacketListener{ - inner: dtlsnet.PacketListenerFromListener(inner), + inner: newUDPPacketListener(innerConn, nil), ws: ws, }, nil }