Browse Source

fix(server): preserve UDP reply source IP

Use packet info on Linux to remember the local destination IP for each client and send DTLS replies from that same address.

Refs #3
pull/162/head
Moroka8 3 weeks ago
parent
commit
acd1eacd53
  1. 6
      server/main.go
  2. 123
      server/pktinfo_linux.go
  3. 9
      server/pktinfo_other.go
  4. 238
      server/udp_listener.go
  5. 5
      server/wrap.go

6
server/main.go

@ -110,7 +110,11 @@ func main() {
}
listener, err = dtls.NewListenerWithOptions(wrapListener, dtlsOpts...)
} else {
listener, err = dtls.ListenWithOptions("udp", addr, dtlsOpts...)
udpListener, lerr := listenUDPForDTLS(addr)
if lerr != nil {
panic(lerr)
}
listener, err = dtls.NewListenerWithOptions(udpListener, dtlsOpts...)
}
if err != nil {
panic(err)

123
server/pktinfo_linux.go

@ -0,0 +1,123 @@
//go:build linux
package main
import (
"fmt"
"net"
"sync"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type packetInfoUDPConn struct {
conn *net.UDPConn
ipv4 *ipv4.PacketConn
ipv6 *ipv6.PacketConn
v6 bool
mu sync.RWMutex
localIPs map[string]net.IP
}
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}
pc := &packetInfoUDPConn{
conn: conn,
ipv4: ipv4.NewPacketConn(conn),
ipv6: ipv6.NewPacketConn(conn),
v6: laddr != nil && laddr.IP != nil && laddr.IP.To4() == nil,
localIPs: make(map[string]net.IP),
}
if pc.v6 {
if err = pc.ipv6.SetControlMessage(ipv6.FlagDst, true); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("enable IPv6 packet info: %w", err)
}
} else if err = pc.ipv4.SetControlMessage(ipv4.FlagDst, true); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("enable IPv4 packet info: %w", err)
}
return pc, nil
}
func (c *packetInfoUDPConn) ReadFrom(p []byte) (int, net.Addr, error) {
if c.v6 {
n, cm, addr, err := c.ipv6.ReadFrom(p)
if err != nil {
return n, addr, err
}
if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil {
c.rememberLocalIP(udpAddr.String(), cm.Dst)
}
return n, addr, nil
}
n, cm, addr, err := c.ipv4.ReadFrom(p)
if err != nil {
return n, addr, err
}
if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil {
c.rememberLocalIP(udpAddr.String(), cm.Dst)
}
return n, addr, nil
}
func (c *packetInfoUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, fmt.Errorf("packet info write: expected *net.UDPAddr, got %T", addr)
}
localIP := c.localIPFor(udpAddr.String())
if localIP == nil {
return c.conn.WriteTo(p, addr)
}
if localIP.To4() != nil {
return c.ipv4.WriteTo(p, &ipv4.ControlMessage{Src: localIP}, addr)
}
return c.ipv6.WriteTo(p, &ipv6.ControlMessage{Src: localIP}, addr)
}
func (c *packetInfoUDPConn) Close() error {
return c.conn.Close()
}
func (c *packetInfoUDPConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *packetInfoUDPConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *packetInfoUDPConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *packetInfoUDPConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *packetInfoUDPConn) rememberLocalIP(remote string, ip net.IP) {
c.mu.Lock()
defer c.mu.Unlock()
c.localIPs[remote] = append(net.IP(nil), ip...)
}
func (c *packetInfoUDPConn) localIPFor(remote string) net.IP {
c.mu.RLock()
defer c.mu.RUnlock()
ip := c.localIPs[remote]
if ip == nil {
return nil
}
return append(net.IP(nil), ip...)
}

9
server/pktinfo_other.go

@ -0,0 +1,9 @@
//go:build !linux
package main
import "net"
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
return net.ListenUDP(network, laddr)
}

238
server/udp_listener.go

@ -0,0 +1,238 @@
package main
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
dtlsnet "github.com/pion/dtls/v3/pkg/net"
"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
"github.com/pion/transport/v4/deadline"
"github.com/pion/transport/v4/packetio"
)
const udpReceiveMTU = 8192
var errUDPPacketListenerClosed = errors.New("udp packet listener closed")
type udpAcceptFilter func([]byte) bool
type udpPacketListener struct {
pConn net.PacketConn
acceptFilter udpAcceptFilter
accepting atomic.Bool
acceptCh chan *udpPacketConn
doneCh chan struct{}
doneOnce sync.Once
connLock sync.Mutex
conns map[string]*udpPacketConn
connWG sync.WaitGroup
readDoneCh chan struct{}
readWG sync.WaitGroup
errRead atomic.Value
errClose atomic.Value
}
func listenUDPForDTLS(addr *net.UDPAddr) (dtlsnet.PacketListener, error) {
pConn, err := listenPacketInfoUDP("udp", addr)
if err != nil {
return nil, err
}
return newUDPPacketListener(pConn, isDTLSHandshakePacket), nil
}
func newUDPPacketListener(pConn net.PacketConn, acceptFilter udpAcceptFilter) dtlsnet.PacketListener {
l := &udpPacketListener{
pConn: pConn,
acceptFilter: acceptFilter,
acceptCh: make(chan *udpPacketConn, 128),
doneCh: make(chan struct{}),
conns: make(map[string]*udpPacketConn),
readDoneCh: make(chan struct{}),
}
l.accepting.Store(true)
l.connWG.Add(1)
l.readWG.Add(2)
go l.readLoop()
go func() {
l.connWG.Wait()
if err := l.pConn.Close(); err != nil {
l.errClose.Store(err)
}
l.readWG.Done()
}()
return l
}
func (l *udpPacketListener) Accept() (net.PacketConn, net.Addr, error) {
select {
case c := <-l.acceptCh:
l.connWG.Add(1)
return c, c.rAddr, nil
case <-l.readDoneCh:
if err, ok := l.errRead.Load().(error); ok {
return nil, nil, err
}
return nil, nil, errUDPPacketListenerClosed
case <-l.doneCh:
return nil, nil, errUDPPacketListenerClosed
}
}
func (l *udpPacketListener) Close() error {
var err error
l.doneOnce.Do(func() {
l.accepting.Store(false)
close(l.doneCh)
l.connLock.Lock()
for {
select {
case c := <-l.acceptCh:
close(c.doneCh)
delete(l.conns, c.rAddr.String())
default:
l.connLock.Unlock()
l.connWG.Done()
l.readWG.Wait()
if errClose, ok := l.errClose.Load().(error); ok {
err = errClose
}
return
}
}
})
return err
}
func (l *udpPacketListener) Addr() net.Addr {
return l.pConn.LocalAddr()
}
func (l *udpPacketListener) readLoop() {
defer l.readWG.Done()
defer close(l.readDoneCh)
buf := make([]byte, udpReceiveMTU)
for {
n, raddr, err := l.pConn.ReadFrom(buf)
if err != nil {
l.errRead.Store(err)
return
}
l.dispatchMsg(raddr, buf[:n])
}
}
func (l *udpPacketListener) dispatchMsg(raddr net.Addr, buf []byte) {
conn, ok := l.getConn(raddr, buf)
if ok {
_, _ = conn.buffer.Write(buf)
}
}
func (l *udpPacketListener) getConn(raddr net.Addr, buf []byte) (*udpPacketConn, bool) {
l.connLock.Lock()
defer l.connLock.Unlock()
conn, ok := l.conns[raddr.String()]
if !ok {
if !l.accepting.Load() {
return nil, false
}
if l.acceptFilter != nil && !l.acceptFilter(buf) {
return nil, false
}
conn = &udpPacketConn{
listener: l,
rAddr: raddr,
buffer: packetio.NewBuffer(),
doneCh: make(chan struct{}),
writeDeadline: deadline.New(),
}
select {
case l.acceptCh <- conn:
l.conns[raddr.String()] = conn
default:
return nil, false
}
}
return conn, true
}
type udpPacketConn struct {
listener *udpPacketListener
rAddr net.Addr
buffer *packetio.Buffer
doneCh chan struct{}
doneOnce sync.Once
writeDeadline *deadline.Deadline
}
func (c *udpPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
n, err := c.buffer.Read(p)
return n, c.rAddr, err
}
func (c *udpPacketConn) WriteTo(p []byte, _ net.Addr) (int, error) {
select {
case <-c.writeDeadline.Done():
return 0, context.DeadlineExceeded
default:
}
return c.listener.pConn.WriteTo(p, c.rAddr)
}
func (c *udpPacketConn) Close() error {
var err error
c.doneOnce.Do(func() {
c.listener.connWG.Done()
close(c.doneCh)
c.listener.connLock.Lock()
delete(c.listener.conns, c.rAddr.String())
c.listener.connLock.Unlock()
if errBuf := c.buffer.Close(); errBuf != nil {
err = errBuf
}
})
return err
}
func (c *udpPacketConn) LocalAddr() net.Addr {
return c.listener.pConn.LocalAddr()
}
func (c *udpPacketConn) SetDeadline(t time.Time) error {
c.writeDeadline.Set(t)
return c.SetReadDeadline(t)
}
func (c *udpPacketConn) SetReadDeadline(t time.Time) error {
return c.buffer.SetReadDeadline(t)
}
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error {
c.writeDeadline.Set(t)
return nil
}
func isDTLSHandshakePacket(packet []byte) bool {
pkts, err := recordlayer.UnpackDatagram(packet)
if err != nil || len(pkts) == 0 {
return false
}
h := &recordlayer.Header{}
if err := h.Unmarshal(pkts[0]); err != nil {
return false
}
return h.ContentType == protocol.ContentTypeHandshake
}

5
server/wrap.go

@ -14,7 +14,6 @@ import (
"time"
dtlsnet "github.com/pion/dtls/v3/pkg/net"
pionudp "github.com/pion/transport/v4/udp"
"golang.org/x/crypto/chacha20poly1305"
)
@ -60,12 +59,12 @@ func listenWrapped(addr *net.UDPAddr, key []byte) (dtlsnet.PacketListener, error
if err != nil {
return nil, err
}
inner, err := pionudp.Listen("udp", addr)
innerConn, err := listenPacketInfoUDP("udp", addr)
if err != nil {
return nil, fmt.Errorf("wrap: udp listen: %w", err)
}
return &wrapPacketListener{
inner: dtlsnet.PacketListenerFromListener(inner),
inner: newUDPPacketListener(innerConn, nil),
ws: ws,
}, nil
}

Loading…
Cancel
Save