Browse Source
Use packet info on Linux to remember the local destination IP for each client and send DTLS replies from that same address. Refs #3pull/162/head
5 changed files with 377 additions and 4 deletions
@ -0,0 +1,123 @@ |
|||||
|
//go:build linux
|
||||
|
|
||||
|
package main |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"net" |
||||
|
"sync" |
||||
|
"time" |
||||
|
|
||||
|
"golang.org/x/net/ipv4" |
||||
|
"golang.org/x/net/ipv6" |
||||
|
) |
||||
|
|
||||
|
type packetInfoUDPConn struct { |
||||
|
conn *net.UDPConn |
||||
|
ipv4 *ipv4.PacketConn |
||||
|
ipv6 *ipv6.PacketConn |
||||
|
v6 bool |
||||
|
|
||||
|
mu sync.RWMutex |
||||
|
localIPs map[string]net.IP |
||||
|
} |
||||
|
|
||||
|
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) { |
||||
|
conn, err := net.ListenUDP(network, laddr) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
pc := &packetInfoUDPConn{ |
||||
|
conn: conn, |
||||
|
ipv4: ipv4.NewPacketConn(conn), |
||||
|
ipv6: ipv6.NewPacketConn(conn), |
||||
|
v6: laddr != nil && laddr.IP != nil && laddr.IP.To4() == nil, |
||||
|
localIPs: make(map[string]net.IP), |
||||
|
} |
||||
|
if pc.v6 { |
||||
|
if err = pc.ipv6.SetControlMessage(ipv6.FlagDst, true); err != nil { |
||||
|
_ = conn.Close() |
||||
|
return nil, fmt.Errorf("enable IPv6 packet info: %w", err) |
||||
|
} |
||||
|
} else if err = pc.ipv4.SetControlMessage(ipv4.FlagDst, true); err != nil { |
||||
|
_ = conn.Close() |
||||
|
return nil, fmt.Errorf("enable IPv4 packet info: %w", err) |
||||
|
} |
||||
|
|
||||
|
return pc, nil |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) ReadFrom(p []byte) (int, net.Addr, error) { |
||||
|
if c.v6 { |
||||
|
n, cm, addr, err := c.ipv6.ReadFrom(p) |
||||
|
if err != nil { |
||||
|
return n, addr, err |
||||
|
} |
||||
|
if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil { |
||||
|
c.rememberLocalIP(udpAddr.String(), cm.Dst) |
||||
|
} |
||||
|
return n, addr, nil |
||||
|
} |
||||
|
|
||||
|
n, cm, addr, err := c.ipv4.ReadFrom(p) |
||||
|
if err != nil { |
||||
|
return n, addr, err |
||||
|
} |
||||
|
if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil { |
||||
|
c.rememberLocalIP(udpAddr.String(), cm.Dst) |
||||
|
} |
||||
|
return n, addr, nil |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { |
||||
|
udpAddr, ok := addr.(*net.UDPAddr) |
||||
|
if !ok { |
||||
|
return 0, fmt.Errorf("packet info write: expected *net.UDPAddr, got %T", addr) |
||||
|
} |
||||
|
|
||||
|
localIP := c.localIPFor(udpAddr.String()) |
||||
|
if localIP == nil { |
||||
|
return c.conn.WriteTo(p, addr) |
||||
|
} |
||||
|
if localIP.To4() != nil { |
||||
|
return c.ipv4.WriteTo(p, &ipv4.ControlMessage{Src: localIP}, addr) |
||||
|
} |
||||
|
return c.ipv6.WriteTo(p, &ipv6.ControlMessage{Src: localIP}, addr) |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) Close() error { |
||||
|
return c.conn.Close() |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) LocalAddr() net.Addr { |
||||
|
return c.conn.LocalAddr() |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) SetDeadline(t time.Time) error { |
||||
|
return c.conn.SetDeadline(t) |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) SetReadDeadline(t time.Time) error { |
||||
|
return c.conn.SetReadDeadline(t) |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) SetWriteDeadline(t time.Time) error { |
||||
|
return c.conn.SetWriteDeadline(t) |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) rememberLocalIP(remote string, ip net.IP) { |
||||
|
c.mu.Lock() |
||||
|
defer c.mu.Unlock() |
||||
|
c.localIPs[remote] = append(net.IP(nil), ip...) |
||||
|
} |
||||
|
|
||||
|
func (c *packetInfoUDPConn) localIPFor(remote string) net.IP { |
||||
|
c.mu.RLock() |
||||
|
defer c.mu.RUnlock() |
||||
|
ip := c.localIPs[remote] |
||||
|
if ip == nil { |
||||
|
return nil |
||||
|
} |
||||
|
return append(net.IP(nil), ip...) |
||||
|
} |
||||
@ -0,0 +1,9 @@ |
|||||
|
//go:build !linux
|
||||
|
|
||||
|
package main |
||||
|
|
||||
|
import "net" |
||||
|
|
||||
|
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) { |
||||
|
return net.ListenUDP(network, laddr) |
||||
|
} |
||||
@ -0,0 +1,238 @@ |
|||||
|
package main |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"errors" |
||||
|
"net" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
|
||||
|
dtlsnet "github.com/pion/dtls/v3/pkg/net" |
||||
|
"github.com/pion/dtls/v3/pkg/protocol" |
||||
|
"github.com/pion/dtls/v3/pkg/protocol/recordlayer" |
||||
|
"github.com/pion/transport/v4/deadline" |
||||
|
"github.com/pion/transport/v4/packetio" |
||||
|
) |
||||
|
|
||||
|
const udpReceiveMTU = 8192 |
||||
|
|
||||
|
var errUDPPacketListenerClosed = errors.New("udp packet listener closed") |
||||
|
|
||||
|
type udpAcceptFilter func([]byte) bool |
||||
|
|
||||
|
type udpPacketListener struct { |
||||
|
pConn net.PacketConn |
||||
|
acceptFilter udpAcceptFilter |
||||
|
|
||||
|
accepting atomic.Bool |
||||
|
acceptCh chan *udpPacketConn |
||||
|
doneCh chan struct{} |
||||
|
doneOnce sync.Once |
||||
|
|
||||
|
connLock sync.Mutex |
||||
|
conns map[string]*udpPacketConn |
||||
|
connWG sync.WaitGroup |
||||
|
|
||||
|
readDoneCh chan struct{} |
||||
|
readWG sync.WaitGroup |
||||
|
errRead atomic.Value |
||||
|
errClose atomic.Value |
||||
|
} |
||||
|
|
||||
|
func listenUDPForDTLS(addr *net.UDPAddr) (dtlsnet.PacketListener, error) { |
||||
|
pConn, err := listenPacketInfoUDP("udp", addr) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return newUDPPacketListener(pConn, isDTLSHandshakePacket), nil |
||||
|
} |
||||
|
|
||||
|
func newUDPPacketListener(pConn net.PacketConn, acceptFilter udpAcceptFilter) dtlsnet.PacketListener { |
||||
|
l := &udpPacketListener{ |
||||
|
pConn: pConn, |
||||
|
acceptFilter: acceptFilter, |
||||
|
acceptCh: make(chan *udpPacketConn, 128), |
||||
|
doneCh: make(chan struct{}), |
||||
|
conns: make(map[string]*udpPacketConn), |
||||
|
readDoneCh: make(chan struct{}), |
||||
|
} |
||||
|
l.accepting.Store(true) |
||||
|
l.connWG.Add(1) |
||||
|
l.readWG.Add(2) |
||||
|
go l.readLoop() |
||||
|
go func() { |
||||
|
l.connWG.Wait() |
||||
|
if err := l.pConn.Close(); err != nil { |
||||
|
l.errClose.Store(err) |
||||
|
} |
||||
|
l.readWG.Done() |
||||
|
}() |
||||
|
return l |
||||
|
} |
||||
|
|
||||
|
func (l *udpPacketListener) Accept() (net.PacketConn, net.Addr, error) { |
||||
|
select { |
||||
|
case c := <-l.acceptCh: |
||||
|
l.connWG.Add(1) |
||||
|
return c, c.rAddr, nil |
||||
|
case <-l.readDoneCh: |
||||
|
if err, ok := l.errRead.Load().(error); ok { |
||||
|
return nil, nil, err |
||||
|
} |
||||
|
return nil, nil, errUDPPacketListenerClosed |
||||
|
case <-l.doneCh: |
||||
|
return nil, nil, errUDPPacketListenerClosed |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (l *udpPacketListener) Close() error { |
||||
|
var err error |
||||
|
l.doneOnce.Do(func() { |
||||
|
l.accepting.Store(false) |
||||
|
close(l.doneCh) |
||||
|
|
||||
|
l.connLock.Lock() |
||||
|
for { |
||||
|
select { |
||||
|
case c := <-l.acceptCh: |
||||
|
close(c.doneCh) |
||||
|
delete(l.conns, c.rAddr.String()) |
||||
|
default: |
||||
|
l.connLock.Unlock() |
||||
|
l.connWG.Done() |
||||
|
l.readWG.Wait() |
||||
|
if errClose, ok := l.errClose.Load().(error); ok { |
||||
|
err = errClose |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
}) |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
func (l *udpPacketListener) Addr() net.Addr { |
||||
|
return l.pConn.LocalAddr() |
||||
|
} |
||||
|
|
||||
|
func (l *udpPacketListener) readLoop() { |
||||
|
defer l.readWG.Done() |
||||
|
defer close(l.readDoneCh) |
||||
|
|
||||
|
buf := make([]byte, udpReceiveMTU) |
||||
|
for { |
||||
|
n, raddr, err := l.pConn.ReadFrom(buf) |
||||
|
if err != nil { |
||||
|
l.errRead.Store(err) |
||||
|
return |
||||
|
} |
||||
|
l.dispatchMsg(raddr, buf[:n]) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (l *udpPacketListener) dispatchMsg(raddr net.Addr, buf []byte) { |
||||
|
conn, ok := l.getConn(raddr, buf) |
||||
|
if ok { |
||||
|
_, _ = conn.buffer.Write(buf) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (l *udpPacketListener) getConn(raddr net.Addr, buf []byte) (*udpPacketConn, bool) { |
||||
|
l.connLock.Lock() |
||||
|
defer l.connLock.Unlock() |
||||
|
|
||||
|
conn, ok := l.conns[raddr.String()] |
||||
|
if !ok { |
||||
|
if !l.accepting.Load() { |
||||
|
return nil, false |
||||
|
} |
||||
|
if l.acceptFilter != nil && !l.acceptFilter(buf) { |
||||
|
return nil, false |
||||
|
} |
||||
|
conn = &udpPacketConn{ |
||||
|
listener: l, |
||||
|
rAddr: raddr, |
||||
|
buffer: packetio.NewBuffer(), |
||||
|
doneCh: make(chan struct{}), |
||||
|
writeDeadline: deadline.New(), |
||||
|
} |
||||
|
select { |
||||
|
case l.acceptCh <- conn: |
||||
|
l.conns[raddr.String()] = conn |
||||
|
default: |
||||
|
return nil, false |
||||
|
} |
||||
|
} |
||||
|
return conn, true |
||||
|
} |
||||
|
|
||||
|
type udpPacketConn struct { |
||||
|
listener *udpPacketListener |
||||
|
rAddr net.Addr |
||||
|
buffer *packetio.Buffer |
||||
|
|
||||
|
doneCh chan struct{} |
||||
|
doneOnce sync.Once |
||||
|
|
||||
|
writeDeadline *deadline.Deadline |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { |
||||
|
n, err := c.buffer.Read(p) |
||||
|
return n, c.rAddr, err |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { |
||||
|
select { |
||||
|
case <-c.writeDeadline.Done(): |
||||
|
return 0, context.DeadlineExceeded |
||||
|
default: |
||||
|
} |
||||
|
return c.listener.pConn.WriteTo(p, c.rAddr) |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) Close() error { |
||||
|
var err error |
||||
|
c.doneOnce.Do(func() { |
||||
|
c.listener.connWG.Done() |
||||
|
close(c.doneCh) |
||||
|
c.listener.connLock.Lock() |
||||
|
delete(c.listener.conns, c.rAddr.String()) |
||||
|
c.listener.connLock.Unlock() |
||||
|
if errBuf := c.buffer.Close(); errBuf != nil { |
||||
|
err = errBuf |
||||
|
} |
||||
|
}) |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) LocalAddr() net.Addr { |
||||
|
return c.listener.pConn.LocalAddr() |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) SetDeadline(t time.Time) error { |
||||
|
c.writeDeadline.Set(t) |
||||
|
return c.SetReadDeadline(t) |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) SetReadDeadline(t time.Time) error { |
||||
|
return c.buffer.SetReadDeadline(t) |
||||
|
} |
||||
|
|
||||
|
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { |
||||
|
c.writeDeadline.Set(t) |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
func isDTLSHandshakePacket(packet []byte) bool { |
||||
|
pkts, err := recordlayer.UnpackDatagram(packet) |
||||
|
if err != nil || len(pkts) == 0 { |
||||
|
return false |
||||
|
} |
||||
|
h := &recordlayer.Header{} |
||||
|
if err := h.Unmarshal(pkts[0]); err != nil { |
||||
|
return false |
||||
|
} |
||||
|
return h.ContentType == protocol.ContentTypeHandshake |
||||
|
} |
||||
Loading…
Reference in new issue