Browse Source
Use packet info on Linux to remember the local destination IP for each client and send DTLS replies from that same address. Refs #3pull/162/head
5 changed files with 377 additions and 4 deletions
@ -0,0 +1,123 @@ |
|||
//go:build linux
|
|||
|
|||
package main |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net" |
|||
"sync" |
|||
"time" |
|||
|
|||
"golang.org/x/net/ipv4" |
|||
"golang.org/x/net/ipv6" |
|||
) |
|||
|
|||
type packetInfoUDPConn struct { |
|||
conn *net.UDPConn |
|||
ipv4 *ipv4.PacketConn |
|||
ipv6 *ipv6.PacketConn |
|||
v6 bool |
|||
|
|||
mu sync.RWMutex |
|||
localIPs map[string]net.IP |
|||
} |
|||
|
|||
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) { |
|||
conn, err := net.ListenUDP(network, laddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
pc := &packetInfoUDPConn{ |
|||
conn: conn, |
|||
ipv4: ipv4.NewPacketConn(conn), |
|||
ipv6: ipv6.NewPacketConn(conn), |
|||
v6: laddr != nil && laddr.IP != nil && laddr.IP.To4() == nil, |
|||
localIPs: make(map[string]net.IP), |
|||
} |
|||
if pc.v6 { |
|||
if err = pc.ipv6.SetControlMessage(ipv6.FlagDst, true); err != nil { |
|||
_ = conn.Close() |
|||
return nil, fmt.Errorf("enable IPv6 packet info: %w", err) |
|||
} |
|||
} else if err = pc.ipv4.SetControlMessage(ipv4.FlagDst, true); err != nil { |
|||
_ = conn.Close() |
|||
return nil, fmt.Errorf("enable IPv4 packet info: %w", err) |
|||
} |
|||
|
|||
return pc, nil |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) ReadFrom(p []byte) (int, net.Addr, error) { |
|||
if c.v6 { |
|||
n, cm, addr, err := c.ipv6.ReadFrom(p) |
|||
if err != nil { |
|||
return n, addr, err |
|||
} |
|||
if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil { |
|||
c.rememberLocalIP(udpAddr.String(), cm.Dst) |
|||
} |
|||
return n, addr, nil |
|||
} |
|||
|
|||
n, cm, addr, err := c.ipv4.ReadFrom(p) |
|||
if err != nil { |
|||
return n, addr, err |
|||
} |
|||
if udpAddr, ok := addr.(*net.UDPAddr); ok && cm != nil && cm.Dst != nil { |
|||
c.rememberLocalIP(udpAddr.String(), cm.Dst) |
|||
} |
|||
return n, addr, nil |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { |
|||
udpAddr, ok := addr.(*net.UDPAddr) |
|||
if !ok { |
|||
return 0, fmt.Errorf("packet info write: expected *net.UDPAddr, got %T", addr) |
|||
} |
|||
|
|||
localIP := c.localIPFor(udpAddr.String()) |
|||
if localIP == nil { |
|||
return c.conn.WriteTo(p, addr) |
|||
} |
|||
if localIP.To4() != nil { |
|||
return c.ipv4.WriteTo(p, &ipv4.ControlMessage{Src: localIP}, addr) |
|||
} |
|||
return c.ipv6.WriteTo(p, &ipv6.ControlMessage{Src: localIP}, addr) |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) Close() error { |
|||
return c.conn.Close() |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) LocalAddr() net.Addr { |
|||
return c.conn.LocalAddr() |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) SetDeadline(t time.Time) error { |
|||
return c.conn.SetDeadline(t) |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) SetReadDeadline(t time.Time) error { |
|||
return c.conn.SetReadDeadline(t) |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) SetWriteDeadline(t time.Time) error { |
|||
return c.conn.SetWriteDeadline(t) |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) rememberLocalIP(remote string, ip net.IP) { |
|||
c.mu.Lock() |
|||
defer c.mu.Unlock() |
|||
c.localIPs[remote] = append(net.IP(nil), ip...) |
|||
} |
|||
|
|||
func (c *packetInfoUDPConn) localIPFor(remote string) net.IP { |
|||
c.mu.RLock() |
|||
defer c.mu.RUnlock() |
|||
ip := c.localIPs[remote] |
|||
if ip == nil { |
|||
return nil |
|||
} |
|||
return append(net.IP(nil), ip...) |
|||
} |
|||
@ -0,0 +1,9 @@ |
|||
//go:build !linux
|
|||
|
|||
package main |
|||
|
|||
import "net" |
|||
|
|||
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) { |
|||
return net.ListenUDP(network, laddr) |
|||
} |
|||
@ -0,0 +1,238 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"net" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
dtlsnet "github.com/pion/dtls/v3/pkg/net" |
|||
"github.com/pion/dtls/v3/pkg/protocol" |
|||
"github.com/pion/dtls/v3/pkg/protocol/recordlayer" |
|||
"github.com/pion/transport/v4/deadline" |
|||
"github.com/pion/transport/v4/packetio" |
|||
) |
|||
|
|||
const udpReceiveMTU = 8192 |
|||
|
|||
var errUDPPacketListenerClosed = errors.New("udp packet listener closed") |
|||
|
|||
type udpAcceptFilter func([]byte) bool |
|||
|
|||
type udpPacketListener struct { |
|||
pConn net.PacketConn |
|||
acceptFilter udpAcceptFilter |
|||
|
|||
accepting atomic.Bool |
|||
acceptCh chan *udpPacketConn |
|||
doneCh chan struct{} |
|||
doneOnce sync.Once |
|||
|
|||
connLock sync.Mutex |
|||
conns map[string]*udpPacketConn |
|||
connWG sync.WaitGroup |
|||
|
|||
readDoneCh chan struct{} |
|||
readWG sync.WaitGroup |
|||
errRead atomic.Value |
|||
errClose atomic.Value |
|||
} |
|||
|
|||
func listenUDPForDTLS(addr *net.UDPAddr) (dtlsnet.PacketListener, error) { |
|||
pConn, err := listenPacketInfoUDP("udp", addr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return newUDPPacketListener(pConn, isDTLSHandshakePacket), nil |
|||
} |
|||
|
|||
func newUDPPacketListener(pConn net.PacketConn, acceptFilter udpAcceptFilter) dtlsnet.PacketListener { |
|||
l := &udpPacketListener{ |
|||
pConn: pConn, |
|||
acceptFilter: acceptFilter, |
|||
acceptCh: make(chan *udpPacketConn, 128), |
|||
doneCh: make(chan struct{}), |
|||
conns: make(map[string]*udpPacketConn), |
|||
readDoneCh: make(chan struct{}), |
|||
} |
|||
l.accepting.Store(true) |
|||
l.connWG.Add(1) |
|||
l.readWG.Add(2) |
|||
go l.readLoop() |
|||
go func() { |
|||
l.connWG.Wait() |
|||
if err := l.pConn.Close(); err != nil { |
|||
l.errClose.Store(err) |
|||
} |
|||
l.readWG.Done() |
|||
}() |
|||
return l |
|||
} |
|||
|
|||
func (l *udpPacketListener) Accept() (net.PacketConn, net.Addr, error) { |
|||
select { |
|||
case c := <-l.acceptCh: |
|||
l.connWG.Add(1) |
|||
return c, c.rAddr, nil |
|||
case <-l.readDoneCh: |
|||
if err, ok := l.errRead.Load().(error); ok { |
|||
return nil, nil, err |
|||
} |
|||
return nil, nil, errUDPPacketListenerClosed |
|||
case <-l.doneCh: |
|||
return nil, nil, errUDPPacketListenerClosed |
|||
} |
|||
} |
|||
|
|||
func (l *udpPacketListener) Close() error { |
|||
var err error |
|||
l.doneOnce.Do(func() { |
|||
l.accepting.Store(false) |
|||
close(l.doneCh) |
|||
|
|||
l.connLock.Lock() |
|||
for { |
|||
select { |
|||
case c := <-l.acceptCh: |
|||
close(c.doneCh) |
|||
delete(l.conns, c.rAddr.String()) |
|||
default: |
|||
l.connLock.Unlock() |
|||
l.connWG.Done() |
|||
l.readWG.Wait() |
|||
if errClose, ok := l.errClose.Load().(error); ok { |
|||
err = errClose |
|||
} |
|||
return |
|||
} |
|||
} |
|||
}) |
|||
return err |
|||
} |
|||
|
|||
func (l *udpPacketListener) Addr() net.Addr { |
|||
return l.pConn.LocalAddr() |
|||
} |
|||
|
|||
func (l *udpPacketListener) readLoop() { |
|||
defer l.readWG.Done() |
|||
defer close(l.readDoneCh) |
|||
|
|||
buf := make([]byte, udpReceiveMTU) |
|||
for { |
|||
n, raddr, err := l.pConn.ReadFrom(buf) |
|||
if err != nil { |
|||
l.errRead.Store(err) |
|||
return |
|||
} |
|||
l.dispatchMsg(raddr, buf[:n]) |
|||
} |
|||
} |
|||
|
|||
func (l *udpPacketListener) dispatchMsg(raddr net.Addr, buf []byte) { |
|||
conn, ok := l.getConn(raddr, buf) |
|||
if ok { |
|||
_, _ = conn.buffer.Write(buf) |
|||
} |
|||
} |
|||
|
|||
func (l *udpPacketListener) getConn(raddr net.Addr, buf []byte) (*udpPacketConn, bool) { |
|||
l.connLock.Lock() |
|||
defer l.connLock.Unlock() |
|||
|
|||
conn, ok := l.conns[raddr.String()] |
|||
if !ok { |
|||
if !l.accepting.Load() { |
|||
return nil, false |
|||
} |
|||
if l.acceptFilter != nil && !l.acceptFilter(buf) { |
|||
return nil, false |
|||
} |
|||
conn = &udpPacketConn{ |
|||
listener: l, |
|||
rAddr: raddr, |
|||
buffer: packetio.NewBuffer(), |
|||
doneCh: make(chan struct{}), |
|||
writeDeadline: deadline.New(), |
|||
} |
|||
select { |
|||
case l.acceptCh <- conn: |
|||
l.conns[raddr.String()] = conn |
|||
default: |
|||
return nil, false |
|||
} |
|||
} |
|||
return conn, true |
|||
} |
|||
|
|||
type udpPacketConn struct { |
|||
listener *udpPacketListener |
|||
rAddr net.Addr |
|||
buffer *packetio.Buffer |
|||
|
|||
doneCh chan struct{} |
|||
doneOnce sync.Once |
|||
|
|||
writeDeadline *deadline.Deadline |
|||
} |
|||
|
|||
func (c *udpPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { |
|||
n, err := c.buffer.Read(p) |
|||
return n, c.rAddr, err |
|||
} |
|||
|
|||
func (c *udpPacketConn) WriteTo(p []byte, _ net.Addr) (int, error) { |
|||
select { |
|||
case <-c.writeDeadline.Done(): |
|||
return 0, context.DeadlineExceeded |
|||
default: |
|||
} |
|||
return c.listener.pConn.WriteTo(p, c.rAddr) |
|||
} |
|||
|
|||
func (c *udpPacketConn) Close() error { |
|||
var err error |
|||
c.doneOnce.Do(func() { |
|||
c.listener.connWG.Done() |
|||
close(c.doneCh) |
|||
c.listener.connLock.Lock() |
|||
delete(c.listener.conns, c.rAddr.String()) |
|||
c.listener.connLock.Unlock() |
|||
if errBuf := c.buffer.Close(); errBuf != nil { |
|||
err = errBuf |
|||
} |
|||
}) |
|||
return err |
|||
} |
|||
|
|||
func (c *udpPacketConn) LocalAddr() net.Addr { |
|||
return c.listener.pConn.LocalAddr() |
|||
} |
|||
|
|||
func (c *udpPacketConn) SetDeadline(t time.Time) error { |
|||
c.writeDeadline.Set(t) |
|||
return c.SetReadDeadline(t) |
|||
} |
|||
|
|||
func (c *udpPacketConn) SetReadDeadline(t time.Time) error { |
|||
return c.buffer.SetReadDeadline(t) |
|||
} |
|||
|
|||
func (c *udpPacketConn) SetWriteDeadline(t time.Time) error { |
|||
c.writeDeadline.Set(t) |
|||
return nil |
|||
} |
|||
|
|||
func isDTLSHandshakePacket(packet []byte) bool { |
|||
pkts, err := recordlayer.UnpackDatagram(packet) |
|||
if err != nil || len(pkts) == 0 { |
|||
return false |
|||
} |
|||
h := &recordlayer.Header{} |
|||
if err := h.Unmarshal(pkts[0]); err != nil { |
|||
return false |
|||
} |
|||
return h.ContentType == protocol.ContentTypeHandshake |
|||
} |
|||
Loading…
Reference in new issue