From 92bfb95bcf5b7f4d9dfad85f198e493eb5b84a6b Mon Sep 17 00:00:00 2001 From: karl Date: Thu, 2 Jan 2025 19:41:31 +0800 Subject: [PATCH 1/2] =?UTF-8?q?DNS=E8=A7=A3=E6=9E=90=E6=97=B6=EF=BC=8CTCP?= =?UTF-8?q?=20=E4=BB=8E=E5=93=AA=E4=B8=AA=20IP=20=E8=BF=9B=E6=9D=A5?= =?UTF-8?q?=EF=BC=8C=E5=B0=B1=E4=BB=8E=E5=93=AA=E4=B8=AA=20IP=20=E8=A7=A3?= =?UTF-8?q?=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chain.go | 7 +++++++ dns.go | 6 +++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/chain.go b/chain.go index f98b781..59b0c95 100644 --- a/chain.go +++ b/chain.go @@ -212,6 +212,13 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op IP: ip, } } + } else if inboundIP != nil && strings.ToLower(network) == "udp" { + ip := inboundIP.(net.IP) + if !ip.IsLoopback() && !ip.IsPrivate() { + d.LocalAddr = &net.UDPAddr{ + IP: ip, + } + } } return d.DialContext(ctx, network, ipAddr) } diff --git a/dns.go b/dns.go index 1b02404..6575846 100644 --- a/dns.go +++ b/dns.go @@ -83,7 +83,11 @@ func (h *dnsHandler) Handle(conn net.Conn) { if resolver == nil { resolver = defaultResolver } - reply, err := resolver.Exchange(context.Background(), b[:n]) + ctx := context.Background() + if inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr); ok { + ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP) + } + reply, err := resolver.Exchange(ctx, b[:n]) if err != nil { log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) return From 5cdc7c1263e0c76d9a10fe1b623e5e7a20556ad0 Mon Sep 17 00:00:00 2001 From: karl Date: Thu, 2 Jan 2025 20:03:05 +0800 Subject: [PATCH 2/2] =?UTF-8?q?DNS=E8=A7=A3=E6=9E=90=E6=97=B6=EF=BC=8CTCP?= =?UTF-8?q?=20=E4=BB=8E=E5=93=AA=E4=B8=AA=20IP=20=E8=BF=9B=E6=9D=A5?= =?UTF-8?q?=EF=BC=8C=E5=B0=B1=E4=BB=8E=E5=93=AA=E4=B8=AA=20IP=20=E8=A7=A3?= =?UTF-8?q?=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chain.go | 9 ++++----- resolver.go | 5 ++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/chain.go b/chain.go index 59b0c95..07d8fdc 100644 --- a/chain.go +++ b/chain.go @@ -154,7 +154,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op ipAddr := address if address != "" { - ipAddr = c.resolve(address, options.Resolver, options.Hosts) + ipAddr = c.resolve(ctx, address, options.Resolver, options.Hosts) if ipAddr == "" { return nil, fmt.Errorf("resolver: domain %s does not exists", address) } @@ -213,8 +213,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op } } } else if inboundIP != nil && strings.ToLower(network) == "udp" { - ip := inboundIP.(net.IP) - if !ip.IsLoopback() && !ip.IsPrivate() { + if ip, ok := inboundIP.(net.IP); ok && !ip.IsLoopback() { d.LocalAddr = &net.UDPAddr{ IP: ip, } @@ -237,7 +236,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op return cc, nil } -func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { +func (*Chain) resolve(ctx context.Context, addr string, resolver Resolver, hosts *Hosts) string { host, port, err := net.SplitHostPort(addr) if err != nil { return addr @@ -247,7 +246,7 @@ func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { return net.JoinHostPort(ip.String(), port) } if resolver != nil { - ips, err := resolver.Resolve(host) + ips, err := resolver.Resolve(ctx, host) if err != nil { log.Logf("[resolver] %s: %v", host, err) } diff --git a/resolver.go b/resolver.go index 69e659d..a848056 100644 --- a/resolver.go +++ b/resolver.go @@ -171,7 +171,7 @@ type Resolver interface { // Init initializes the Resolver instance. Init(opts ...ResolverOption) error // Resolve returns a slice of that host's IPv4 and IPv6 addresses. - Resolve(host string) ([]net.IP, error) + Resolve(ctx context.Context, host string) ([]net.IP, error) // Exchange performs a synchronous query, // It sends the message query and waits for a reply. Exchange(ctx context.Context, query []byte) (reply []byte, err error) @@ -268,7 +268,7 @@ func (r *resolver) copyServers() []NameServer { return servers } -func (r *resolver) Resolve(host string) (ips []net.IP, err error) { +func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { r.mux.RLock() domain := r.domain r.mux.RUnlock() @@ -281,7 +281,6 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { host = host + "." + domain } - ctx := context.Background() for _, ns := range r.copyServers() { ips, err = r.resolve(ctx, ns.exchanger, host) if err != nil {