diff --git a/chain.go b/chain.go index f98b781..07d8fdc 100644 --- a/chain.go +++ b/chain.go @@ -154,7 +154,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op ipAddr := address if address != "" { - ipAddr = c.resolve(address, options.Resolver, options.Hosts) + ipAddr = c.resolve(ctx, address, options.Resolver, options.Hosts) if ipAddr == "" { return nil, fmt.Errorf("resolver: domain %s does not exists", address) } @@ -212,6 +212,12 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op IP: ip, } } + } else if inboundIP != nil && strings.ToLower(network) == "udp" { + if ip, ok := inboundIP.(net.IP); ok && !ip.IsLoopback() { + d.LocalAddr = &net.UDPAddr{ + IP: ip, + } + } } return d.DialContext(ctx, network, ipAddr) } @@ -230,7 +236,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op return cc, nil } -func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { +func (*Chain) resolve(ctx context.Context, addr string, resolver Resolver, hosts *Hosts) string { host, port, err := net.SplitHostPort(addr) if err != nil { return addr @@ -240,7 +246,7 @@ func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { return net.JoinHostPort(ip.String(), port) } if resolver != nil { - ips, err := resolver.Resolve(host) + ips, err := resolver.Resolve(ctx, host) if err != nil { log.Logf("[resolver] %s: %v", host, err) } diff --git a/dns.go b/dns.go index 1b02404..6575846 100644 --- a/dns.go +++ b/dns.go @@ -83,7 +83,11 @@ func (h *dnsHandler) Handle(conn net.Conn) { if resolver == nil { resolver = defaultResolver } - reply, err := resolver.Exchange(context.Background(), b[:n]) + ctx := context.Background() + if inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr); ok { + ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP) + } + reply, err := resolver.Exchange(ctx, b[:n]) if err != nil { log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) return diff --git a/resolver.go b/resolver.go index 69e659d..a848056 100644 --- a/resolver.go +++ b/resolver.go @@ -171,7 +171,7 @@ type Resolver interface { // Init initializes the Resolver instance. Init(opts ...ResolverOption) error // Resolve returns a slice of that host's IPv4 and IPv6 addresses. - Resolve(host string) ([]net.IP, error) + Resolve(ctx context.Context, host string) ([]net.IP, error) // Exchange performs a synchronous query, // It sends the message query and waits for a reply. Exchange(ctx context.Context, query []byte) (reply []byte, err error) @@ -268,7 +268,7 @@ func (r *resolver) copyServers() []NameServer { return servers } -func (r *resolver) Resolve(host string) (ips []net.IP, err error) { +func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { r.mux.RLock() domain := r.domain r.mux.RUnlock() @@ -281,7 +281,6 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { host = host + "." + domain } - ctx := context.Background() for _, ns := range r.copyServers() { ips, err = r.resolve(ctx, ns.exchanger, host) if err != nil {