You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

123 lines
2.8 KiB

//go:build linux
package main
import (
"fmt"
"net"
"sync"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// packetInfoUDPConn is the net.PacketConn returned by listenPacketInfoUDP.
// It wraps a UDP socket, records per remote address the local IP that the
// remote's latest datagram was received on, and replies to that remote with
// the recorded IP set as the source in a per-packet control message.
type packetInfoUDPConn struct {
	// conn is the underlying socket. Plain writes, deadlines, LocalAddr and
	// Close go straight to it.
	conn *net.UDPConn

	// v6 reports whether conn was configured for IPv6 control messages.
	// ReadFrom reads through ipv6 when set, through ipv4 otherwise.
	v6 bool

	// ipv4 and ipv6 are views of conn that expose per-packet control
	// messages for the respective address family.
	ipv4 *ipv4.PacketConn
	ipv6 *ipv6.PacketConn

	// mu guards localIPs.
	mu sync.RWMutex
	// localIPs maps a remote *net.UDPAddr String() form to the local IP its
	// most recent datagram arrived on.
	localIPs map[string]net.IP
}
// listenPacketInfoUDP opens a UDP socket with net.ListenUDP(network, laddr)
// and enables receipt of each datagram's destination address. The returned
// net.PacketConn uses that information so WriteTo replies from the local IP a
// peer last reached.
//
// The IPv4 or IPv6 control-message API is chosen from the address the socket
// was actually bound to (conn.LocalAddr), not from laddr alone. For example,
// network "udp6" with a nil laddr, or "udp" with a nil or unspecified IP,
// yields an IPv6 socket even though laddr carries no IPv6 address. Deciding
// from laddr would enable IPv4 packet info on such a socket, so IPv6 peers
// would never have their local address recorded. The laddr-based test
// remains as a fallback if LocalAddr reports no IP.
//
// On any error after the socket is opened, the socket is closed and a nil
// PacketConn is returned together with the wrapped error.
func listenPacketInfoUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
	conn, err := net.ListenUDP(network, laddr)
	if err != nil {
		return nil, err
	}

	// Fallback: the original heuristic based only on the requested address.
	v6 := laddr != nil && laddr.IP != nil && laddr.IP.To4() == nil
	// Preferred: the family of the address the kernel actually bound. An
	// IPv4 socket reports an address with To4() != nil; an IPv6 socket,
	// including the "::" wildcard, does not.
	if bound, ok := conn.LocalAddr().(*net.UDPAddr); ok && bound.IP != nil {
		v6 = bound.IP.To4() == nil
	}

	pc := &packetInfoUDPConn{
		conn:     conn,
		ipv4:     ipv4.NewPacketConn(conn),
		ipv6:     ipv6.NewPacketConn(conn),
		v6:       v6,
		localIPs: make(map[string]net.IP),
	}

	if pc.v6 {
		if err := pc.ipv6.SetControlMessage(ipv6.FlagDst, true); err != nil {
			// Best-effort cleanup; the SetControlMessage error is the one the
			// caller needs to see.
			_ = conn.Close()
			return nil, fmt.Errorf("enable IPv6 packet info: %w", err)
		}
		return pc, nil
	}
	if err := pc.ipv4.SetControlMessage(ipv4.FlagDst, true); err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("enable IPv4 packet info: %w", err)
	}
	return pc, nil
}
// ReadFrom reads one datagram into p and returns its length and sender, as
// net.PacketConn.ReadFrom does. It reads through the IPv6 or IPv4
// control-message view depending on c.v6. When the sender is a *net.UDPAddr
// and a control message carrying the destination address (Dst) accompanied
// the datagram, that address is recorded for the sender. A later WriteTo to
// the same sender then uses it as the source IP. On error, the values
// returned by the underlying read are passed through unchanged.
func (c *packetInfoUDPConn) ReadFrom(p []byte) (int, net.Addr, error) {
	var (
		size   int
		sender net.Addr
		dstIP  net.IP
		rerr   error
	)

	if c.v6 {
		var msg *ipv6.ControlMessage
		size, msg, sender, rerr = c.ipv6.ReadFrom(p)
		if msg != nil {
			dstIP = msg.Dst
		}
	} else {
		var msg *ipv4.ControlMessage
		size, msg, sender, rerr = c.ipv4.ReadFrom(p)
		if msg != nil {
			dstIP = msg.Dst
		}
	}

	if rerr != nil {
		return size, sender, rerr
	}

	// Only UDP peers can be matched again in WriteTo, and only a known
	// destination address is worth recording.
	if peer, isUDP := sender.(*net.UDPAddr); isUDP && dstIP != nil {
		c.rememberLocalIP(peer.String(), dstIP)
	}
	return size, sender, nil
}
// WriteTo sends p to addr, which must be a *net.UDPAddr; any other type is
// rejected with an error before anything is sent.
//
// If ReadFrom recorded a local IP for addr, the datagram carries that IP as
// its source in a control message. The IPv4 form is used when the recorded IP
// has a 4-byte representation, the IPv6 form otherwise. With nothing
// recorded, the datagram is written directly on the underlying socket.
func (c *packetInfoUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	peer, isUDP := addr.(*net.UDPAddr)
	if !isUDP {
		return 0, fmt.Errorf("packet info write: expected *net.UDPAddr, got %T", addr)
	}

	src := c.localIPFor(peer.String())
	switch {
	case src == nil:
		return c.conn.WriteTo(p, addr)
	case src.To4() != nil:
		return c.ipv4.WriteTo(p, &ipv4.ControlMessage{Src: src}, addr)
	default:
		return c.ipv6.WriteTo(p, &ipv6.ControlMessage{Src: src}, addr)
	}
}
// Close closes the underlying UDP socket, which the ipv4 and ipv6 views were
// constructed on, and returns its error.
func (c *packetInfoUDPConn) Close() error {
	sock := c.conn
	return sock.Close()
}
// LocalAddr returns the local address reported by the underlying UDP socket.
func (c *packetInfoUDPConn) LocalAddr() net.Addr {
	sock := c.conn
	return sock.LocalAddr()
}
// SetDeadline sets the read and write deadline on the underlying UDP socket.
func (c *packetInfoUDPConn) SetDeadline(t time.Time) error {
	sock := c.conn
	return sock.SetDeadline(t)
}
// SetReadDeadline sets the read deadline on the underlying UDP socket.
func (c *packetInfoUDPConn) SetReadDeadline(t time.Time) error {
	sock := c.conn
	return sock.SetReadDeadline(t)
}
// SetWriteDeadline sets the write deadline on the underlying UDP socket.
func (c *packetInfoUDPConn) SetWriteDeadline(t time.Time) error {
	sock := c.conn
	return sock.SetWriteDeadline(t)
}
// rememberLocalIP records ip as the local address on which the latest
// datagram from remote (a *net.UDPAddr String() key) was received. WriteTo
// looks this up to reply from the same address.
//
// The table is keyed per remote address, and nothing else ever removes
// entries. Without a cap, a long-running listener would grow it by one entry
// for every new client, including spoofed sources. Once it holds
// maxRememberedLocalIPs entries, adding a new remote first evicts one
// arbitrary existing entry. A later reply to an evicted remote simply takes
// WriteTo's plain-write path.
func (c *packetInfoUDPConn) rememberLocalIP(remote string, ip net.IP) {
	const maxRememberedLocalIPs = 1 << 16

	// Fast path: a returning peer arriving on the same local IP needs no
	// write lock and no allocation.
	c.mu.RLock()
	prev, known := c.localIPs[remote]
	c.mu.RUnlock()
	if known && prev.Equal(ip) {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	if _, exists := c.localIPs[remote]; !exists && len(c.localIPs) >= maxRememberedLocalIPs {
		// Go map iteration order is unspecified, so this removes an
		// arbitrary entry; that is sufficient to keep the table bounded.
		for victim := range c.localIPs {
			delete(c.localIPs, victim)
			break
		}
	}
	// Store a private copy so later changes to the caller's slice cannot
	// alter the recorded address.
	c.localIPs[remote] = append(net.IP(nil), ip...)
}
// localIPFor returns a fresh copy of the local IP recorded for remote (a
// *net.UDPAddr String() key), or nil when nothing, or an empty address, is
// recorded. Callers own the returned slice.
func (c *packetInfoUDPConn) localIPFor(remote string) net.IP {
	c.mu.RLock()
	defer c.mu.RUnlock()

	stored := c.localIPs[remote]
	if len(stored) == 0 {
		return nil
	}
	clone := make(net.IP, len(stored))
	copy(clone, stored)
	return clone
}