Browse Source

DNS解析时,TCP 从哪个 IP 进来,就从哪个 IP 解析

pull/1077/head
karl 2 years ago
parent
commit
5cdc7c1263
  1. 9
      chain.go
  2. 5
      resolver.go

9
chain.go

@ -154,7 +154,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op
ipAddr := address
if address != "" {
ipAddr = c.resolve(address, options.Resolver, options.Hosts)
ipAddr = c.resolve(ctx, address, options.Resolver, options.Hosts)
if ipAddr == "" {
return nil, fmt.Errorf("resolver: domain %s does not exists", address)
}
@ -213,8 +213,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op
}
}
} else if inboundIP != nil && strings.ToLower(network) == "udp" {
ip := inboundIP.(net.IP)
if !ip.IsLoopback() && !ip.IsPrivate() {
if ip, ok := inboundIP.(net.IP); ok && !ip.IsLoopback() {
d.LocalAddr = &net.UDPAddr{
IP: ip,
}
@ -237,7 +236,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op
return cc, nil
}
func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
func (*Chain) resolve(ctx context.Context, addr string, resolver Resolver, hosts *Hosts) string {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return addr
@ -247,7 +246,7 @@ func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string {
return net.JoinHostPort(ip.String(), port)
}
if resolver != nil {
ips, err := resolver.Resolve(host)
ips, err := resolver.Resolve(ctx, host)
if err != nil {
log.Logf("[resolver] %s: %v", host, err)
}

5
resolver.go

@ -171,7 +171,7 @@ type Resolver interface {
// Init initializes the Resolver instance.
Init(opts ...ResolverOption) error
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
Resolve(host string) ([]net.IP, error)
Resolve(ctx context.Context, host string) ([]net.IP, error)
// Exchange performs a synchronous query,
// It sends the message query and waits for a reply.
Exchange(ctx context.Context, query []byte) (reply []byte, err error)
@ -268,7 +268,7 @@ func (r *resolver) copyServers() []NameServer {
return servers
}
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) {
r.mux.RLock()
domain := r.domain
r.mux.RUnlock()
@ -281,7 +281,6 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
host = host + "." + domain
}
ctx := context.Background()
for _, ns := range r.copyServers() {
ips, err = r.resolve(ctx, ns.exchanger, host)
if err != nil {

Loading…
Cancel
Save