committed by
GitHub
8 changed files with 2237 additions and 518 deletions
@ -0,0 +1,601 @@ |
|||
// DNS-over-HTTPS resolver for mobile networks where UDP/53 is blocked or
|
|||
// spoofed.
|
|||
|
|||
package main |
|||
|
|||
import ( |
|||
"bytes" |
|||
"context" |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"log" |
|||
"net" |
|||
"net/http" |
|||
"sort" |
|||
"sync" |
|||
"sync/atomic" |
|||
"time" |
|||
|
|||
"github.com/miekg/dns" |
|||
|
|||
// Embedded Mozilla CA roots for CGO_ENABLED=0 builds (Android).
|
|||
_ "golang.org/x/crypto/x509roots/fallback" |
|||
) |
|||
|
|||
const ( |
|||
dohQueryTimeout = 6 * time.Second |
|||
dohCacheMinTTL = 10 * time.Second |
|||
dohCacheMaxTTL = 1 * time.Hour |
|||
dohMaxResponseBytes = 64 * 1024 |
|||
dohContentType = "application/dns-message" |
|||
|
|||
dohDialerTimeout = 5 * time.Second |
|||
dohDialerKeepAlive = 30 * time.Second |
|||
appDialerTimeout = 20 * time.Second |
|||
appDialerKeepAlive = 30 * time.Second |
|||
|
|||
forwarderUDPBufSize = 4096 |
|||
forwarderTCPReadDL = 30 * time.Second |
|||
forwarderTCPWriteDL = 10 * time.Second |
|||
autoUDPBudget = 1500 * time.Millisecond |
|||
) |
|||
|
|||
// DohEndpoint describes a single DNS-over-HTTPS server together with the IPs
|
|||
// we bootstrap to — so that resolving the endpoint hostname does not itself
|
|||
// require DNS.
|
|||
type DohEndpoint struct { |
|||
URL string |
|||
Hostname string |
|||
BootstrapIPs []string |
|||
} |
|||
|
|||
// Yandex is tried first because it tends to stay reachable on RU mobile
|
|||
// operators even when international resolvers get blocked; Google and
|
|||
// Cloudflare follow as fallbacks.
|
|||
var defaultDohEndpoints = []DohEndpoint{ |
|||
{"https://common.dot.dns.yandex.net/dns-query", "common.dot.dns.yandex.net", []string{"77.88.8.8", "77.88.8.1"}}, |
|||
{"https://secure.dot.dns.yandex.net/dns-query", "secure.dot.dns.yandex.net", []string{"77.88.8.88", "77.88.8.2"}}, |
|||
{"https://family.dot.dns.yandex.net/dns-query", "family.dot.dns.yandex.net", []string{"77.88.8.7", "77.88.8.3"}}, |
|||
{"https://dns.google/dns-query", "dns.google", []string{"8.8.8.8", "8.8.4.4"}}, |
|||
{"https://cloudflare-dns.com/dns-query", "cloudflare-dns.com", []string{"1.1.1.1", "1.0.0.1"}}, |
|||
} |
|||
|
|||
// DohResolver resolves hostnames to IPs via DNS-over-HTTPS (RFC 8484).
|
|||
type DohResolver struct { |
|||
endpoints []DohEndpoint |
|||
client *http.Client |
|||
cache *dohCache |
|||
} |
|||
|
|||
// NewDohResolver constructs a resolver using defaultDohEndpoints if endpoints
|
|||
// is nil. Endpoint hostnames are dialed by IP using BootstrapIPs, so the DoH
|
|||
// transport never depends on the system resolver.
|
|||
func NewDohResolver(endpoints []DohEndpoint) *DohResolver { |
|||
if len(endpoints) == 0 { |
|||
endpoints = defaultDohEndpoints |
|||
} |
|||
return &DohResolver{ |
|||
endpoints: endpoints, |
|||
client: &http.Client{Timeout: dohQueryTimeout, Transport: newBootstrapTransport(endpoints)}, |
|||
cache: newDohCache(), |
|||
} |
|||
} |
|||
|
|||
// newDohResolverWithClient is a test hook that skips the bootstrap transport.
|
|||
func newDohResolverWithClient(endpoints []DohEndpoint, client *http.Client) *DohResolver { |
|||
return &DohResolver{endpoints: endpoints, client: client, cache: newDohCache()} |
|||
} |
|||
|
|||
// newBootstrapTransport returns an http.Transport whose DialContext only
|
|||
// knows how to reach the configured DoH endpoint hostnames, by mapping each
|
|||
// to its BootstrapIPs.
|
|||
func newBootstrapTransport(endpoints []DohEndpoint) *http.Transport { |
|||
bootstrap := make(map[string][]string, len(endpoints)) |
|||
for _, ep := range endpoints { |
|||
bootstrap[ep.Hostname] = ep.BootstrapIPs |
|||
} |
|||
dialer := &net.Dialer{Timeout: dohDialerTimeout, KeepAlive: dohDialerKeepAlive} |
|||
|
|||
return &http.Transport{ |
|||
MaxIdleConns: 8, |
|||
MaxIdleConnsPerHost: 2, |
|||
IdleConnTimeout: 90 * time.Second, |
|||
ForceAttemptHTTP2: true, |
|||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, |
|||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { |
|||
host, port, err := net.SplitHostPort(addr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
ips, ok := bootstrap[host] |
|||
if !ok { |
|||
return nil, fmt.Errorf("doh: no bootstrap IPs for %q", host) |
|||
} |
|||
var lastErr error |
|||
for _, ip := range ips { |
|||
conn, derr := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) |
|||
if derr == nil { |
|||
return conn, nil |
|||
} |
|||
lastErr = derr |
|||
} |
|||
return nil, lastErr |
|||
}, |
|||
} |
|||
} |
|||
|
|||
// LookupIPAddr resolves host to a combined list of A+AAAA IPs (IPv4 first).
|
|||
// Cached results bypass the network entirely.
|
|||
func (r *DohResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IP, error) { |
|||
if ip := net.ParseIP(host); ip != nil { |
|||
return []net.IP{ip}, nil |
|||
} |
|||
if ips, ok := r.cache.get(host); ok { |
|||
return ips, nil |
|||
} |
|||
|
|||
type res struct { |
|||
ips []net.IP |
|||
ttl time.Duration |
|||
err error |
|||
} |
|||
results := make(chan res, 2) |
|||
for _, qt := range [...]uint16{dns.TypeA, dns.TypeAAAA} { |
|||
go func(qtype uint16) { |
|||
ips, ttl, err := r.queryIPs(ctx, host, qtype) |
|||
results <- res{ips, ttl, err} |
|||
}(qt) |
|||
} |
|||
|
|||
var ( |
|||
all []net.IP |
|||
lastErr error |
|||
minTTL = dohCacheMaxTTL |
|||
) |
|||
for range 2 { |
|||
rr := <-results |
|||
if rr.err != nil { |
|||
lastErr = rr.err |
|||
continue |
|||
} |
|||
all = append(all, rr.ips...) |
|||
if rr.ttl > 0 && rr.ttl < minTTL { |
|||
minTTL = rr.ttl |
|||
} |
|||
} |
|||
if len(all) == 0 { |
|||
if lastErr == nil { |
|||
lastErr = fmt.Errorf("doh: no records for %s", host) |
|||
} |
|||
return nil, lastErr |
|||
} |
|||
|
|||
// IPv4 before IPv6 — better compat with mobile IPv4-only CGNAT.
|
|||
sort.SliceStable(all, func(i, j int) bool { |
|||
return (all[i].To4() != nil) && (all[j].To4() == nil) |
|||
}) |
|||
|
|||
if minTTL < dohCacheMinTTL { |
|||
minTTL = dohCacheMinTTL |
|||
} |
|||
r.cache.set(host, all, minTTL) |
|||
return all, nil |
|||
} |
|||
|
|||
// queryIPs issues one DoH query for qtype, walking endpoints until one
|
|||
// succeeds, and parses the wire reply into IPs + min TTL.
|
|||
func (r *DohResolver) queryIPs(ctx context.Context, host string, qtype uint16) ([]net.IP, time.Duration, error) { |
|||
m := new(dns.Msg) |
|||
m.SetQuestion(dns.Fqdn(host), qtype) |
|||
m.Id = 0 // RFC 8484 §4.1 — zero ID is cache-friendly on shared caches.
|
|||
m.RecursionDesired = true |
|||
wire, err := m.Pack() |
|||
if err != nil { |
|||
return nil, 0, fmt.Errorf("doh: pack query: %w", err) |
|||
} |
|||
|
|||
body, ep, err := r.forwardRaw(ctx, wire) |
|||
if err != nil { |
|||
return nil, 0, err |
|||
} |
|||
ips, ttl, err := parseAnswer(body) |
|||
if err != nil { |
|||
return nil, 0, fmt.Errorf("doh: parse %s: %w", ep.Hostname, err) |
|||
} |
|||
log.Printf("[DoH] %s %s via %s → %d IPs (ttl %s)", host, dns.TypeToString[qtype], ep.Hostname, len(ips), ttl) |
|||
return ips, ttl, nil |
|||
} |
|||
|
|||
// parseAnswer decodes a DNS wire reply into A/AAAA records and the minimum TTL.
|
|||
func parseAnswer(body []byte) ([]net.IP, time.Duration, error) { |
|||
reply := new(dns.Msg) |
|||
if err := reply.Unpack(body); err != nil { |
|||
return nil, 0, fmt.Errorf("unpack: %w", err) |
|||
} |
|||
if reply.Rcode != dns.RcodeSuccess { |
|||
return nil, 0, fmt.Errorf("rcode %s", dns.RcodeToString[reply.Rcode]) |
|||
} |
|||
var ( |
|||
ips []net.IP |
|||
minTTL uint32 |
|||
) |
|||
updateTTL := func(ttl uint32) { |
|||
if minTTL == 0 || ttl < minTTL { |
|||
minTTL = ttl |
|||
} |
|||
} |
|||
for _, ans := range reply.Answer { |
|||
switch a := ans.(type) { |
|||
case *dns.A: |
|||
ips = append(ips, a.A) |
|||
updateTTL(a.Hdr.Ttl) |
|||
case *dns.AAAA: |
|||
ips = append(ips, a.AAAA) |
|||
updateTTL(a.Hdr.Ttl) |
|||
} |
|||
} |
|||
return ips, time.Duration(minTTL) * time.Second, nil |
|||
} |
|||
|
|||
// forwardRaw POSTs an opaque DNS-wire query to the configured DoH endpoints
|
|||
// in order and returns the first successful raw response together with the
|
|||
// endpoint that produced it. No parsing — useful for the local forwarder
|
|||
// which needs to pass through whatever the upstream resolver answers
|
|||
// (RESINFO/HTTPS/SVCB/EDNS options/…).
|
|||
func (r *DohResolver) forwardRaw(ctx context.Context, query []byte) ([]byte, DohEndpoint, error) { |
|||
if len(r.endpoints) == 0 { |
|||
return nil, DohEndpoint{}, errors.New("doh: no endpoints configured") |
|||
} |
|||
var lastErr error |
|||
for _, ep := range r.endpoints { |
|||
body, err := r.postWire(ctx, ep, query) |
|||
if err != nil { |
|||
log.Printf("[DoH] %s: %v", ep.Hostname, err) |
|||
lastErr = err |
|||
continue |
|||
} |
|||
return body, ep, nil |
|||
} |
|||
return nil, DohEndpoint{}, lastErr |
|||
} |
|||
|
|||
// postWire performs a single application/dns-message POST to one endpoint.
|
|||
func (r *DohResolver) postWire(ctx context.Context, ep DohEndpoint, query []byte) ([]byte, error) { |
|||
req, err := http.NewRequestWithContext(ctx, "POST", ep.URL, bytes.NewReader(query)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("build request: %w", err) |
|||
} |
|||
req.Header.Set("Content-Type", dohContentType) |
|||
req.Header.Set("Accept", dohContentType) |
|||
|
|||
resp, err := r.client.Do(req) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
defer func() { _ = resp.Body.Close() }() |
|||
|
|||
if resp.StatusCode != http.StatusOK { |
|||
_, _ = io.Copy(io.Discard, resp.Body) |
|||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode) |
|||
} |
|||
body, err := io.ReadAll(io.LimitReader(resp.Body, dohMaxResponseBytes)) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("read body: %w", err) |
|||
} |
|||
return body, nil |
|||
} |
|||
|
|||
type dohCacheEntry struct { |
|||
ips []net.IP |
|||
expiry time.Time |
|||
} |
|||
|
|||
type dohCache struct { |
|||
mu sync.RWMutex |
|||
m map[string]dohCacheEntry |
|||
} |
|||
|
|||
func newDohCache() *dohCache { |
|||
return &dohCache{m: make(map[string]dohCacheEntry)} |
|||
} |
|||
|
|||
func (c *dohCache) get(host string) ([]net.IP, bool) { |
|||
c.mu.RLock() |
|||
e, ok := c.m[host] |
|||
c.mu.RUnlock() |
|||
if !ok || time.Now().After(e.expiry) { |
|||
return nil, false |
|||
} |
|||
out := make([]net.IP, len(e.ips)) |
|||
copy(out, e.ips) |
|||
return out, true |
|||
} |
|||
|
|||
func (c *dohCache) set(host string, ips []net.IP, ttl time.Duration) { |
|||
if ttl <= 0 { |
|||
return |
|||
} |
|||
if ttl > dohCacheMaxTTL { |
|||
ttl = dohCacheMaxTTL |
|||
} |
|||
cp := make([]net.IP, len(ips)) |
|||
copy(cp, ips) |
|||
c.mu.Lock() |
|||
c.m[host] = dohCacheEntry{ips: cp, expiry: time.Now().Add(ttl)} |
|||
c.mu.Unlock() |
|||
} |
|||
|
|||
// Go's net.Resolver dials this stub like a regular nameserver, which avoids
|
|||
// the many edge cases of a fake-net.Conn approach (RESINFO probes, EDNS
|
|||
// handshakes, truncation, …). Whatever it reads on UDP/TCP is sent verbatim
|
|||
// to a DoH endpoint and the wire response is sent back to the client.
|
|||
|
|||
type dohForwarder struct { |
|||
udpAddr string |
|||
tcpAddr string |
|||
} |
|||
|
|||
var ( |
|||
dohForwarderOnce sync.Once |
|||
dohForwarderInst *dohForwarder |
|||
dohForwarderErr error |
|||
) |
|||
|
|||
// sharedDohForwarder lazily starts a process-wide forwarder bound to the
|
|||
// supplied resolver. The first caller wins; subsequent callers reuse the
|
|||
// same forwarder regardless of what they pass in.
|
|||
func sharedDohForwarder(r *DohResolver) (*dohForwarder, error) { |
|||
dohForwarderOnce.Do(func() { |
|||
dohForwarderInst, dohForwarderErr = startDohForwarder(r) |
|||
}) |
|||
return dohForwarderInst, dohForwarderErr |
|||
} |
|||
|
|||
func startDohForwarder(r *DohResolver) (_ *dohForwarder, err error) { |
|||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("doh forwarder: listen UDP: %w", err) |
|||
} |
|||
defer func() { |
|||
if err != nil { |
|||
_ = udpConn.Close() |
|||
} |
|||
}() |
|||
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("doh forwarder: listen TCP: %w", err) |
|||
} |
|||
defer func() { |
|||
if err != nil { |
|||
_ = tcpLn.Close() |
|||
} |
|||
}() |
|||
|
|||
fwd := &dohForwarder{ |
|||
udpAddr: udpConn.LocalAddr().String(), |
|||
tcpAddr: tcpLn.Addr().String(), |
|||
} |
|||
log.Printf("[DoH] forwarder listening udp=%s tcp=%s", fwd.udpAddr, fwd.tcpAddr) |
|||
|
|||
go fwd.serveUDP(udpConn, r) |
|||
go fwd.serveTCP(tcpLn, r) |
|||
return fwd, nil |
|||
} |
|||
|
|||
func (f *dohForwarder) serveUDP(conn *net.UDPConn, r *DohResolver) { |
|||
defer func() { _ = conn.Close() }() |
|||
buf := make([]byte, forwarderUDPBufSize) |
|||
for { |
|||
n, client, err := conn.ReadFromUDP(buf) |
|||
if err != nil { |
|||
log.Printf("[DoH] udp read: %v", err) |
|||
return |
|||
} |
|||
query := append([]byte(nil), buf[:n]...) |
|||
go func(q []byte, c *net.UDPAddr) { |
|||
ctx, cancel := context.WithTimeout(context.Background(), dohQueryTimeout) |
|||
defer cancel() |
|||
resp, _, err := r.forwardRaw(ctx, q) |
|||
if err != nil { |
|||
log.Printf("[DoH] udp forward failed: %v", err) |
|||
return |
|||
} |
|||
if _, err := conn.WriteToUDP(resp, c); err != nil { |
|||
log.Printf("[DoH] udp write: %v", err) |
|||
} |
|||
}(query, client) |
|||
} |
|||
} |
|||
|
|||
func (f *dohForwarder) serveTCP(ln *net.TCPListener, r *DohResolver) { |
|||
defer func() { _ = ln.Close() }() |
|||
for { |
|||
conn, err := ln.Accept() |
|||
if err != nil { |
|||
log.Printf("[DoH] tcp accept: %v", err) |
|||
return |
|||
} |
|||
go handleDohForwarderTCP(conn, r) |
|||
} |
|||
} |
|||
|
|||
func handleDohForwarderTCP(conn net.Conn, r *DohResolver) { |
|||
defer func() { _ = conn.Close() }() |
|||
for { |
|||
_ = conn.SetReadDeadline(time.Now().Add(forwarderTCPReadDL)) |
|||
var lenBuf [2]byte |
|||
if _, err := io.ReadFull(conn, lenBuf[:]); err != nil { |
|||
return |
|||
} |
|||
qlen := int(lenBuf[0])<<8 | int(lenBuf[1]) |
|||
if qlen == 0 || qlen > forwarderUDPBufSize { |
|||
return |
|||
} |
|||
query := make([]byte, qlen) |
|||
if _, err := io.ReadFull(conn, query); err != nil { |
|||
return |
|||
} |
|||
|
|||
ctx, cancel := context.WithTimeout(context.Background(), dohQueryTimeout) |
|||
resp, _, err := r.forwardRaw(ctx, query) |
|||
cancel() |
|||
if err != nil { |
|||
log.Printf("[DoH] tcp forward failed: %v", err) |
|||
return |
|||
} |
|||
out := make([]byte, 2+len(resp)) |
|||
out[0] = byte(len(resp) >> 8) |
|||
out[1] = byte(len(resp)) |
|||
copy(out[2:], resp) |
|||
_ = conn.SetWriteDeadline(time.Now().Add(forwarderTCPWriteDL)) |
|||
if _, err := conn.Write(out); err != nil { |
|||
return |
|||
} |
|||
} |
|||
} |
|||
|
|||
// dohForwarderDial returns a Resolver.Dial that connects to the local DoH
|
|||
// forwarder over UDP or TCP (whichever the resolver asked for).
|
|||
func dohForwarderDial(r *DohResolver) dialFunc { |
|||
return func(ctx context.Context, network, _ string) (net.Conn, error) { |
|||
fwd, err := sharedDohForwarder(r) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
var d net.Dialer |
|||
switch network { |
|||
case "tcp", "tcp4", "tcp6": |
|||
return d.DialContext(ctx, "tcp", fwd.tcpAddr) |
|||
default: |
|||
return d.DialContext(ctx, "udp", fwd.udpAddr) |
|||
} |
|||
} |
|||
} |
|||
|
|||
const ( |
|||
DNSModeUDP = "udp" |
|||
DNSModeDoH = "doh" |
|||
DNSModeAuto = "auto" |
|||
) |
|||
|
|||
var udpDNSServers = []string{ |
|||
"77.88.8.8:53", "77.88.8.1:53", |
|||
"8.8.8.8:53", "8.8.4.4:53", |
|||
"1.1.1.1:53", "1.0.0.1:53", |
|||
} |
|||
|
|||
type dialFunc = func(context.Context, string, string) (net.Conn, error) |
|||
|
|||
// buildDialer returns a net.Dialer whose internal Go resolver uses the
|
|||
// chosen DNS transport. In "auto" mode the first total-failure of UDP/53
|
|||
// sticks the process onto DoH for the rest of its lifetime.
|
|||
func buildDialer(mode string, r *DohResolver) net.Dialer { |
|||
switch mode { |
|||
case DNSModeUDP: |
|||
return newAppDialer(udpDNSDial) |
|||
case DNSModeDoH: |
|||
return newAppDialer(dohForwarderDial(r)) |
|||
case DNSModeAuto: |
|||
return newAppDialer(autoDial(r)) |
|||
default: |
|||
log.Panicf("unknown DNS mode %q", mode) |
|||
return net.Dialer{} |
|||
} |
|||
} |
|||
|
|||
// newAppDialer wraps a Resolver.Dial with the timeouts used everywhere in
|
|||
// the app for outbound TCP/HTTP connections.
|
|||
func newAppDialer(dial dialFunc) net.Dialer { |
|||
return net.Dialer{ |
|||
Timeout: appDialerTimeout, |
|||
KeepAlive: appDialerKeepAlive, |
|||
Resolver: &net.Resolver{PreferGo: true, Dial: dial}, |
|||
} |
|||
} |
|||
|
|||
// udpDNSDial picks the first reachable UDP/53 resolver from udpDNSServers.
|
|||
func udpDNSDial(ctx context.Context, _ string, _ string) (net.Conn, error) { |
|||
var ( |
|||
d net.Dialer |
|||
lastErr error |
|||
) |
|||
for _, s := range udpDNSServers { |
|||
conn, err := d.DialContext(ctx, "udp", s) |
|||
if err == nil { |
|||
return conn, nil |
|||
} |
|||
lastErr = err |
|||
} |
|||
if lastErr == nil { |
|||
lastErr = errors.New("no UDP DNS servers available") |
|||
} |
|||
return nil, lastErr |
|||
} |
|||
|
|||
// autoDial returns a Dial that probes UDP/53 once with a real DNS round-trip;
|
|||
// if the probe fails it latches onto DoH for the rest of the process. Built
|
|||
// for Android, where the network can flip between Wi-Fi (UDP/53 works) and
|
|||
// mobile (UDP/53 blocked).
|
|||
//
|
|||
// A simple dial-timeout doesn't work for UDP because UDP "dial" is
|
|||
// connectionless and always succeeds instantly. The only way to know whether
|
|||
// UDP/53 actually works is to send a real query and wait for a response.
|
|||
func autoDial(r *DohResolver) dialFunc { |
|||
var ( |
|||
probed sync.Once |
|||
useDoH atomic.Bool |
|||
doh = dohForwarderDial(r) |
|||
) |
|||
return func(ctx context.Context, network, addr string) (net.Conn, error) { |
|||
probed.Do(func() { |
|||
if udpProbe(autoUDPBudget) { |
|||
log.Printf("[DNS] UDP/53 probe OK, using UDP") |
|||
} else { |
|||
log.Printf("[DNS] UDP/53 unreachable; sticky-switching to DoH") |
|||
useDoH.Store(true) |
|||
} |
|||
}) |
|||
if useDoH.Load() { |
|||
return doh(ctx, network, addr) |
|||
} |
|||
return udpDNSDial(ctx, network, addr) |
|||
} |
|||
} |
|||
|
|||
// udpProbe sends a real DNS A query for a well-known domain via UDP and
|
|||
// checks whether any response arrives within the deadline. We try the first
|
|||
// two servers from udpDNSServers under a shared deadline — if neither
|
|||
// responds, UDP/53 is blocked.
|
|||
func udpProbe(timeout time.Duration) bool { |
|||
m := new(dns.Msg) |
|||
m.SetQuestion("dns.google.", dns.TypeA) |
|||
m.RecursionDesired = true |
|||
wire, err := m.Pack() |
|||
if err != nil { |
|||
return false |
|||
} |
|||
|
|||
deadline := time.Now().Add(timeout) |
|||
buf := make([]byte, 512) |
|||
limit := min(len(udpDNSServers), 2) |
|||
for _, server := range udpDNSServers[:limit] { |
|||
remaining := time.Until(deadline) |
|||
if remaining <= 0 { |
|||
break |
|||
} |
|||
conn, err := net.DialTimeout("udp", server, remaining) |
|||
if err != nil { |
|||
continue |
|||
} |
|||
_ = conn.SetDeadline(deadline) |
|||
_, _ = conn.Write(wire) |
|||
n, err := conn.Read(buf) |
|||
_ = conn.Close() |
|||
if err == nil && n > 12 { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
} |
|||
@ -0,0 +1,197 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"context" |
|||
"io" |
|||
"net" |
|||
"net/http" |
|||
"net/http/httptest" |
|||
"sync/atomic" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/miekg/dns" |
|||
) |
|||
|
|||
// dohAnswer builds a wire-format DNS reply for a single question with one
|
|||
// answer of the matching type (A or AAAA). TTL is returned as-is.
|
|||
func dohAnswer(t *testing.T, query []byte, ip net.IP, ttl uint32) []byte { |
|||
t.Helper() |
|||
req := new(dns.Msg) |
|||
if err := req.Unpack(query); err != nil { |
|||
t.Fatalf("unpack query: %v", err) |
|||
} |
|||
reply := new(dns.Msg) |
|||
reply.SetReply(req) |
|||
if len(req.Question) != 1 { |
|||
t.Fatalf("expected 1 question, got %d", len(req.Question)) |
|||
} |
|||
q := req.Question[0] |
|||
switch q.Qtype { |
|||
case dns.TypeA: |
|||
if v4 := ip.To4(); v4 != nil { |
|||
reply.Answer = append(reply.Answer, &dns.A{ |
|||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, |
|||
A: v4, |
|||
}) |
|||
} |
|||
case dns.TypeAAAA: |
|||
if ip.To4() == nil { |
|||
reply.Answer = append(reply.Answer, &dns.AAAA{ |
|||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, |
|||
AAAA: ip, |
|||
}) |
|||
} |
|||
} |
|||
out, err := reply.Pack() |
|||
if err != nil { |
|||
t.Fatalf("pack reply: %v", err) |
|||
} |
|||
return out |
|||
} |
|||
|
|||
func readWire(t *testing.T, r io.Reader) []byte { |
|||
t.Helper() |
|||
b, err := io.ReadAll(r) |
|||
if err != nil { |
|||
t.Fatalf("read body: %v", err) |
|||
} |
|||
return b |
|||
} |
|||
|
|||
func TestDohResolver_LookupIPAddr_Success(t *testing.T) { |
|||
var hits atomic.Int32 |
|||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
hits.Add(1) |
|||
if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" { |
|||
t.Errorf("wrong Content-Type: %q", ct) |
|||
} |
|||
body := readWire(t, r.Body) |
|||
w.Header().Set("Content-Type", "application/dns-message") |
|||
w.WriteHeader(http.StatusOK) |
|||
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("93.184.216.34"), 300)) |
|||
})) |
|||
defer srv.Close() |
|||
|
|||
r := newDohResolverWithClient( |
|||
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, |
|||
srv.Client(), |
|||
) |
|||
|
|||
ips, err := r.LookupIPAddr(context.Background(), "example.com") |
|||
if err != nil { |
|||
t.Fatalf("lookup: %v", err) |
|||
} |
|||
if len(ips) == 0 { |
|||
t.Fatalf("no ips returned") |
|||
} |
|||
if ips[0].String() != "93.184.216.34" { |
|||
t.Fatalf("unexpected ip %s", ips[0]) |
|||
} |
|||
// Two concurrent queries fire (A + AAAA), so we expect 2 hits.
|
|||
if got := hits.Load(); got != 2 { |
|||
t.Fatalf("expected 2 HTTP hits, got %d", got) |
|||
} |
|||
} |
|||
|
|||
func TestDohResolver_Fallback(t *testing.T) { |
|||
var firstHits, secondHits atomic.Int32 |
|||
first := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
firstHits.Add(1) |
|||
w.WriteHeader(http.StatusInternalServerError) |
|||
})) |
|||
defer first.Close() |
|||
second := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
secondHits.Add(1) |
|||
body := readWire(t, r.Body) |
|||
w.Header().Set("Content-Type", "application/dns-message") |
|||
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("1.2.3.4"), 300)) |
|||
})) |
|||
defer second.Close() |
|||
|
|||
r := newDohResolverWithClient( |
|||
[]DohEndpoint{ |
|||
{URL: first.URL, Hostname: "first", BootstrapIPs: []string{"127.0.0.1"}}, |
|||
{URL: second.URL, Hostname: "second", BootstrapIPs: []string{"127.0.0.1"}}, |
|||
}, |
|||
first.Client(), |
|||
) |
|||
ips, err := r.LookupIPAddr(context.Background(), "example.com") |
|||
if err != nil { |
|||
t.Fatalf("lookup: %v", err) |
|||
} |
|||
if len(ips) != 1 || ips[0].String() != "1.2.3.4" { |
|||
t.Fatalf("unexpected ips: %v", ips) |
|||
} |
|||
if firstHits.Load() == 0 || secondHits.Load() == 0 { |
|||
t.Fatalf("fallback did not probe both endpoints: first=%d second=%d", firstHits.Load(), secondHits.Load()) |
|||
} |
|||
} |
|||
|
|||
func TestDohResolver_Cache(t *testing.T) { |
|||
var hits atomic.Int32 |
|||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
hits.Add(1) |
|||
body := readWire(t, r.Body) |
|||
w.Header().Set("Content-Type", "application/dns-message") |
|||
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("5.6.7.8"), 300)) |
|||
})) |
|||
defer srv.Close() |
|||
|
|||
r := newDohResolverWithClient( |
|||
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, |
|||
srv.Client(), |
|||
) |
|||
if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil { |
|||
t.Fatalf("first lookup: %v", err) |
|||
} |
|||
firstHits := hits.Load() |
|||
if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil { |
|||
t.Fatalf("second lookup: %v", err) |
|||
} |
|||
if hits.Load() != firstHits { |
|||
t.Fatalf("cache miss: expected %d HTTP hits, got %d", firstHits, hits.Load()) |
|||
} |
|||
} |
|||
|
|||
func TestAutoDial_StickyAfterUDPFailure(t *testing.T) { |
|||
// DoH backend: always responds with a valid wire-format reply.
|
|||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|||
body := readWire(t, r.Body) |
|||
w.Header().Set("Content-Type", "application/dns-message") |
|||
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("9.9.9.9"), 300)) |
|||
})) |
|||
defer srv.Close() |
|||
|
|||
resolver := newDohResolverWithClient( |
|||
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, |
|||
srv.Client(), |
|||
) |
|||
|
|||
dial := autoDial(resolver) |
|||
|
|||
// Poison udpDNSServers so that udpProbe (real DNS round-trip) fails
|
|||
// immediately — net.DialTimeout rejects the malformed address.
|
|||
old := udpDNSServers |
|||
udpDNSServers = []string{"not-a-valid-host-port"} |
|||
defer func() { udpDNSServers = old }() |
|||
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
|||
defer cancel() |
|||
|
|||
conn1, err := dial(ctx, "udp", "unused") |
|||
if err != nil { |
|||
t.Fatalf("first dial: %v", err) |
|||
} |
|||
_ = conn1.Close() |
|||
|
|||
// Second call must skip UDP entirely. We assert this by poisoning
|
|||
// udpDNSServers with a value that would fail parsing — if the dialer
|
|||
// touches UDP again the call errors loudly.
|
|||
udpDNSServers = []string{"still-not-a-valid-host-port"} |
|||
conn2, err := dial(ctx, "udp", "unused") |
|||
if err != nil { |
|||
t.Fatalf("second dial: %v", err) |
|||
} |
|||
_ = conn2.Close() |
|||
} |
|||
File diff suppressed because it is too large
Loading…
Reference in new issue