committed by
GitHub
8 changed files with 2237 additions and 518 deletions
@ -0,0 +1,601 @@ |
|||||
|
// DNS-over-HTTPS resolver for mobile networks where UDP/53 is blocked or
|
||||
|
// spoofed.
|
||||
|
|
||||
|
package main |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"context" |
||||
|
"crypto/tls" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"log" |
||||
|
"net" |
||||
|
"net/http" |
||||
|
"sort" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/miekg/dns" |
||||
|
|
||||
|
// Embedded Mozilla CA roots for CGO_ENABLED=0 builds (Android).
|
||||
|
_ "golang.org/x/crypto/x509roots/fallback" |
||||
|
) |
||||
|
|
||||
|
const ( |
||||
|
dohQueryTimeout = 6 * time.Second |
||||
|
dohCacheMinTTL = 10 * time.Second |
||||
|
dohCacheMaxTTL = 1 * time.Hour |
||||
|
dohMaxResponseBytes = 64 * 1024 |
||||
|
dohContentType = "application/dns-message" |
||||
|
|
||||
|
dohDialerTimeout = 5 * time.Second |
||||
|
dohDialerKeepAlive = 30 * time.Second |
||||
|
appDialerTimeout = 20 * time.Second |
||||
|
appDialerKeepAlive = 30 * time.Second |
||||
|
|
||||
|
forwarderUDPBufSize = 4096 |
||||
|
forwarderTCPReadDL = 30 * time.Second |
||||
|
forwarderTCPWriteDL = 10 * time.Second |
||||
|
autoUDPBudget = 1500 * time.Millisecond |
||||
|
) |
||||
|
|
||||
|
// DohEndpoint describes a single DNS-over-HTTPS server together with the IPs
|
||||
|
// we bootstrap to — so that resolving the endpoint hostname does not itself
|
||||
|
// require DNS.
|
||||
|
type DohEndpoint struct { |
||||
|
URL string |
||||
|
Hostname string |
||||
|
BootstrapIPs []string |
||||
|
} |
||||
|
|
||||
|
// Yandex is tried first because it tends to stay reachable on RU mobile
|
||||
|
// operators even when international resolvers get blocked; Google and
|
||||
|
// Cloudflare follow as fallbacks.
|
||||
|
var defaultDohEndpoints = []DohEndpoint{ |
||||
|
{"https://common.dot.dns.yandex.net/dns-query", "common.dot.dns.yandex.net", []string{"77.88.8.8", "77.88.8.1"}}, |
||||
|
{"https://secure.dot.dns.yandex.net/dns-query", "secure.dot.dns.yandex.net", []string{"77.88.8.88", "77.88.8.2"}}, |
||||
|
{"https://family.dot.dns.yandex.net/dns-query", "family.dot.dns.yandex.net", []string{"77.88.8.7", "77.88.8.3"}}, |
||||
|
{"https://dns.google/dns-query", "dns.google", []string{"8.8.8.8", "8.8.4.4"}}, |
||||
|
{"https://cloudflare-dns.com/dns-query", "cloudflare-dns.com", []string{"1.1.1.1", "1.0.0.1"}}, |
||||
|
} |
||||
|
|
||||
|
// DohResolver resolves hostnames to IPs via DNS-over-HTTPS (RFC 8484).
|
||||
|
type DohResolver struct { |
||||
|
endpoints []DohEndpoint |
||||
|
client *http.Client |
||||
|
cache *dohCache |
||||
|
} |
||||
|
|
||||
|
// NewDohResolver constructs a resolver using defaultDohEndpoints if endpoints
|
||||
|
// is nil. Endpoint hostnames are dialed by IP using BootstrapIPs, so the DoH
|
||||
|
// transport never depends on the system resolver.
|
||||
|
func NewDohResolver(endpoints []DohEndpoint) *DohResolver { |
||||
|
if len(endpoints) == 0 { |
||||
|
endpoints = defaultDohEndpoints |
||||
|
} |
||||
|
return &DohResolver{ |
||||
|
endpoints: endpoints, |
||||
|
client: &http.Client{Timeout: dohQueryTimeout, Transport: newBootstrapTransport(endpoints)}, |
||||
|
cache: newDohCache(), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// newDohResolverWithClient is a test hook that skips the bootstrap transport.
|
||||
|
func newDohResolverWithClient(endpoints []DohEndpoint, client *http.Client) *DohResolver { |
||||
|
return &DohResolver{endpoints: endpoints, client: client, cache: newDohCache()} |
||||
|
} |
||||
|
|
||||
|
// newBootstrapTransport returns an http.Transport whose DialContext only
|
||||
|
// knows how to reach the configured DoH endpoint hostnames, by mapping each
|
||||
|
// to its BootstrapIPs.
|
||||
|
func newBootstrapTransport(endpoints []DohEndpoint) *http.Transport { |
||||
|
bootstrap := make(map[string][]string, len(endpoints)) |
||||
|
for _, ep := range endpoints { |
||||
|
bootstrap[ep.Hostname] = ep.BootstrapIPs |
||||
|
} |
||||
|
dialer := &net.Dialer{Timeout: dohDialerTimeout, KeepAlive: dohDialerKeepAlive} |
||||
|
|
||||
|
return &http.Transport{ |
||||
|
MaxIdleConns: 8, |
||||
|
MaxIdleConnsPerHost: 2, |
||||
|
IdleConnTimeout: 90 * time.Second, |
||||
|
ForceAttemptHTTP2: true, |
||||
|
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, |
||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { |
||||
|
host, port, err := net.SplitHostPort(addr) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
ips, ok := bootstrap[host] |
||||
|
if !ok { |
||||
|
return nil, fmt.Errorf("doh: no bootstrap IPs for %q", host) |
||||
|
} |
||||
|
var lastErr error |
||||
|
for _, ip := range ips { |
||||
|
conn, derr := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) |
||||
|
if derr == nil { |
||||
|
return conn, nil |
||||
|
} |
||||
|
lastErr = derr |
||||
|
} |
||||
|
return nil, lastErr |
||||
|
}, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// LookupIPAddr resolves host to a combined list of A+AAAA IPs (IPv4 first).
|
||||
|
// Cached results bypass the network entirely.
|
||||
|
func (r *DohResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IP, error) { |
||||
|
if ip := net.ParseIP(host); ip != nil { |
||||
|
return []net.IP{ip}, nil |
||||
|
} |
||||
|
if ips, ok := r.cache.get(host); ok { |
||||
|
return ips, nil |
||||
|
} |
||||
|
|
||||
|
type res struct { |
||||
|
ips []net.IP |
||||
|
ttl time.Duration |
||||
|
err error |
||||
|
} |
||||
|
results := make(chan res, 2) |
||||
|
for _, qt := range [...]uint16{dns.TypeA, dns.TypeAAAA} { |
||||
|
go func(qtype uint16) { |
||||
|
ips, ttl, err := r.queryIPs(ctx, host, qtype) |
||||
|
results <- res{ips, ttl, err} |
||||
|
}(qt) |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
all []net.IP |
||||
|
lastErr error |
||||
|
minTTL = dohCacheMaxTTL |
||||
|
) |
||||
|
for range 2 { |
||||
|
rr := <-results |
||||
|
if rr.err != nil { |
||||
|
lastErr = rr.err |
||||
|
continue |
||||
|
} |
||||
|
all = append(all, rr.ips...) |
||||
|
if rr.ttl > 0 && rr.ttl < minTTL { |
||||
|
minTTL = rr.ttl |
||||
|
} |
||||
|
} |
||||
|
if len(all) == 0 { |
||||
|
if lastErr == nil { |
||||
|
lastErr = fmt.Errorf("doh: no records for %s", host) |
||||
|
} |
||||
|
return nil, lastErr |
||||
|
} |
||||
|
|
||||
|
// IPv4 before IPv6 — better compat with mobile IPv4-only CGNAT.
|
||||
|
sort.SliceStable(all, func(i, j int) bool { |
||||
|
return (all[i].To4() != nil) && (all[j].To4() == nil) |
||||
|
}) |
||||
|
|
||||
|
if minTTL < dohCacheMinTTL { |
||||
|
minTTL = dohCacheMinTTL |
||||
|
} |
||||
|
r.cache.set(host, all, minTTL) |
||||
|
return all, nil |
||||
|
} |
||||
|
|
||||
|
// queryIPs issues one DoH query for qtype, walking endpoints until one
|
||||
|
// succeeds, and parses the wire reply into IPs + min TTL.
|
||||
|
func (r *DohResolver) queryIPs(ctx context.Context, host string, qtype uint16) ([]net.IP, time.Duration, error) { |
||||
|
m := new(dns.Msg) |
||||
|
m.SetQuestion(dns.Fqdn(host), qtype) |
||||
|
m.Id = 0 // RFC 8484 §4.1 — zero ID is cache-friendly on shared caches.
|
||||
|
m.RecursionDesired = true |
||||
|
wire, err := m.Pack() |
||||
|
if err != nil { |
||||
|
return nil, 0, fmt.Errorf("doh: pack query: %w", err) |
||||
|
} |
||||
|
|
||||
|
body, ep, err := r.forwardRaw(ctx, wire) |
||||
|
if err != nil { |
||||
|
return nil, 0, err |
||||
|
} |
||||
|
ips, ttl, err := parseAnswer(body) |
||||
|
if err != nil { |
||||
|
return nil, 0, fmt.Errorf("doh: parse %s: %w", ep.Hostname, err) |
||||
|
} |
||||
|
log.Printf("[DoH] %s %s via %s → %d IPs (ttl %s)", host, dns.TypeToString[qtype], ep.Hostname, len(ips), ttl) |
||||
|
return ips, ttl, nil |
||||
|
} |
||||
|
|
||||
|
// parseAnswer decodes a DNS wire reply into A/AAAA records and the minimum TTL.
|
||||
|
func parseAnswer(body []byte) ([]net.IP, time.Duration, error) { |
||||
|
reply := new(dns.Msg) |
||||
|
if err := reply.Unpack(body); err != nil { |
||||
|
return nil, 0, fmt.Errorf("unpack: %w", err) |
||||
|
} |
||||
|
if reply.Rcode != dns.RcodeSuccess { |
||||
|
return nil, 0, fmt.Errorf("rcode %s", dns.RcodeToString[reply.Rcode]) |
||||
|
} |
||||
|
var ( |
||||
|
ips []net.IP |
||||
|
minTTL uint32 |
||||
|
) |
||||
|
updateTTL := func(ttl uint32) { |
||||
|
if minTTL == 0 || ttl < minTTL { |
||||
|
minTTL = ttl |
||||
|
} |
||||
|
} |
||||
|
for _, ans := range reply.Answer { |
||||
|
switch a := ans.(type) { |
||||
|
case *dns.A: |
||||
|
ips = append(ips, a.A) |
||||
|
updateTTL(a.Hdr.Ttl) |
||||
|
case *dns.AAAA: |
||||
|
ips = append(ips, a.AAAA) |
||||
|
updateTTL(a.Hdr.Ttl) |
||||
|
} |
||||
|
} |
||||
|
return ips, time.Duration(minTTL) * time.Second, nil |
||||
|
} |
||||
|
|
||||
|
// forwardRaw POSTs an opaque DNS-wire query to the configured DoH endpoints
|
||||
|
// in order and returns the first successful raw response together with the
|
||||
|
// endpoint that produced it. No parsing — useful for the local forwarder
|
||||
|
// which needs to pass through whatever the upstream resolver answers
|
||||
|
// (RESINFO/HTTPS/SVCB/EDNS options/…).
|
||||
|
func (r *DohResolver) forwardRaw(ctx context.Context, query []byte) ([]byte, DohEndpoint, error) { |
||||
|
if len(r.endpoints) == 0 { |
||||
|
return nil, DohEndpoint{}, errors.New("doh: no endpoints configured") |
||||
|
} |
||||
|
var lastErr error |
||||
|
for _, ep := range r.endpoints { |
||||
|
body, err := r.postWire(ctx, ep, query) |
||||
|
if err != nil { |
||||
|
log.Printf("[DoH] %s: %v", ep.Hostname, err) |
||||
|
lastErr = err |
||||
|
continue |
||||
|
} |
||||
|
return body, ep, nil |
||||
|
} |
||||
|
return nil, DohEndpoint{}, lastErr |
||||
|
} |
||||
|
|
||||
|
// postWire performs a single application/dns-message POST to one endpoint.
|
||||
|
func (r *DohResolver) postWire(ctx context.Context, ep DohEndpoint, query []byte) ([]byte, error) { |
||||
|
req, err := http.NewRequestWithContext(ctx, "POST", ep.URL, bytes.NewReader(query)) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("build request: %w", err) |
||||
|
} |
||||
|
req.Header.Set("Content-Type", dohContentType) |
||||
|
req.Header.Set("Accept", dohContentType) |
||||
|
|
||||
|
resp, err := r.client.Do(req) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
defer func() { _ = resp.Body.Close() }() |
||||
|
|
||||
|
if resp.StatusCode != http.StatusOK { |
||||
|
_, _ = io.Copy(io.Discard, resp.Body) |
||||
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode) |
||||
|
} |
||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, dohMaxResponseBytes)) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("read body: %w", err) |
||||
|
} |
||||
|
return body, nil |
||||
|
} |
||||
|
|
||||
|
type dohCacheEntry struct { |
||||
|
ips []net.IP |
||||
|
expiry time.Time |
||||
|
} |
||||
|
|
||||
|
type dohCache struct { |
||||
|
mu sync.RWMutex |
||||
|
m map[string]dohCacheEntry |
||||
|
} |
||||
|
|
||||
|
func newDohCache() *dohCache { |
||||
|
return &dohCache{m: make(map[string]dohCacheEntry)} |
||||
|
} |
||||
|
|
||||
|
func (c *dohCache) get(host string) ([]net.IP, bool) { |
||||
|
c.mu.RLock() |
||||
|
e, ok := c.m[host] |
||||
|
c.mu.RUnlock() |
||||
|
if !ok || time.Now().After(e.expiry) { |
||||
|
return nil, false |
||||
|
} |
||||
|
out := make([]net.IP, len(e.ips)) |
||||
|
copy(out, e.ips) |
||||
|
return out, true |
||||
|
} |
||||
|
|
||||
|
func (c *dohCache) set(host string, ips []net.IP, ttl time.Duration) { |
||||
|
if ttl <= 0 { |
||||
|
return |
||||
|
} |
||||
|
if ttl > dohCacheMaxTTL { |
||||
|
ttl = dohCacheMaxTTL |
||||
|
} |
||||
|
cp := make([]net.IP, len(ips)) |
||||
|
copy(cp, ips) |
||||
|
c.mu.Lock() |
||||
|
c.m[host] = dohCacheEntry{ips: cp, expiry: time.Now().Add(ttl)} |
||||
|
c.mu.Unlock() |
||||
|
} |
||||
|
|
||||
|
// Go's net.Resolver dials this stub like a regular nameserver, which avoids
|
||||
|
// the many edge cases of a fake-net.Conn approach (RESINFO probes, EDNS
|
||||
|
// handshakes, truncation, …). Whatever it reads on UDP/TCP is sent verbatim
|
||||
|
// to a DoH endpoint and the wire response is sent back to the client.
|
||||
|
|
||||
|
type dohForwarder struct { |
||||
|
udpAddr string |
||||
|
tcpAddr string |
||||
|
} |
||||
|
|
||||
|
var ( |
||||
|
dohForwarderOnce sync.Once |
||||
|
dohForwarderInst *dohForwarder |
||||
|
dohForwarderErr error |
||||
|
) |
||||
|
|
||||
|
// sharedDohForwarder lazily starts a process-wide forwarder bound to the
|
||||
|
// supplied resolver. The first caller wins; subsequent callers reuse the
|
||||
|
// same forwarder regardless of what they pass in.
|
||||
|
func sharedDohForwarder(r *DohResolver) (*dohForwarder, error) { |
||||
|
dohForwarderOnce.Do(func() { |
||||
|
dohForwarderInst, dohForwarderErr = startDohForwarder(r) |
||||
|
}) |
||||
|
return dohForwarderInst, dohForwarderErr |
||||
|
} |
||||
|
|
||||
|
func startDohForwarder(r *DohResolver) (_ *dohForwarder, err error) { |
||||
|
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("doh forwarder: listen UDP: %w", err) |
||||
|
} |
||||
|
defer func() { |
||||
|
if err != nil { |
||||
|
_ = udpConn.Close() |
||||
|
} |
||||
|
}() |
||||
|
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("doh forwarder: listen TCP: %w", err) |
||||
|
} |
||||
|
defer func() { |
||||
|
if err != nil { |
||||
|
_ = tcpLn.Close() |
||||
|
} |
||||
|
}() |
||||
|
|
||||
|
fwd := &dohForwarder{ |
||||
|
udpAddr: udpConn.LocalAddr().String(), |
||||
|
tcpAddr: tcpLn.Addr().String(), |
||||
|
} |
||||
|
log.Printf("[DoH] forwarder listening udp=%s tcp=%s", fwd.udpAddr, fwd.tcpAddr) |
||||
|
|
||||
|
go fwd.serveUDP(udpConn, r) |
||||
|
go fwd.serveTCP(tcpLn, r) |
||||
|
return fwd, nil |
||||
|
} |
||||
|
|
||||
|
func (f *dohForwarder) serveUDP(conn *net.UDPConn, r *DohResolver) { |
||||
|
defer func() { _ = conn.Close() }() |
||||
|
buf := make([]byte, forwarderUDPBufSize) |
||||
|
for { |
||||
|
n, client, err := conn.ReadFromUDP(buf) |
||||
|
if err != nil { |
||||
|
log.Printf("[DoH] udp read: %v", err) |
||||
|
return |
||||
|
} |
||||
|
query := append([]byte(nil), buf[:n]...) |
||||
|
go func(q []byte, c *net.UDPAddr) { |
||||
|
ctx, cancel := context.WithTimeout(context.Background(), dohQueryTimeout) |
||||
|
defer cancel() |
||||
|
resp, _, err := r.forwardRaw(ctx, q) |
||||
|
if err != nil { |
||||
|
log.Printf("[DoH] udp forward failed: %v", err) |
||||
|
return |
||||
|
} |
||||
|
if _, err := conn.WriteToUDP(resp, c); err != nil { |
||||
|
log.Printf("[DoH] udp write: %v", err) |
||||
|
} |
||||
|
}(query, client) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (f *dohForwarder) serveTCP(ln *net.TCPListener, r *DohResolver) { |
||||
|
defer func() { _ = ln.Close() }() |
||||
|
for { |
||||
|
conn, err := ln.Accept() |
||||
|
if err != nil { |
||||
|
log.Printf("[DoH] tcp accept: %v", err) |
||||
|
return |
||||
|
} |
||||
|
go handleDohForwarderTCP(conn, r) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func handleDohForwarderTCP(conn net.Conn, r *DohResolver) { |
||||
|
defer func() { _ = conn.Close() }() |
||||
|
for { |
||||
|
_ = conn.SetReadDeadline(time.Now().Add(forwarderTCPReadDL)) |
||||
|
var lenBuf [2]byte |
||||
|
if _, err := io.ReadFull(conn, lenBuf[:]); err != nil { |
||||
|
return |
||||
|
} |
||||
|
qlen := int(lenBuf[0])<<8 | int(lenBuf[1]) |
||||
|
if qlen == 0 || qlen > forwarderUDPBufSize { |
||||
|
return |
||||
|
} |
||||
|
query := make([]byte, qlen) |
||||
|
if _, err := io.ReadFull(conn, query); err != nil { |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
ctx, cancel := context.WithTimeout(context.Background(), dohQueryTimeout) |
||||
|
resp, _, err := r.forwardRaw(ctx, query) |
||||
|
cancel() |
||||
|
if err != nil { |
||||
|
log.Printf("[DoH] tcp forward failed: %v", err) |
||||
|
return |
||||
|
} |
||||
|
out := make([]byte, 2+len(resp)) |
||||
|
out[0] = byte(len(resp) >> 8) |
||||
|
out[1] = byte(len(resp)) |
||||
|
copy(out[2:], resp) |
||||
|
_ = conn.SetWriteDeadline(time.Now().Add(forwarderTCPWriteDL)) |
||||
|
if _, err := conn.Write(out); err != nil { |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// dohForwarderDial returns a Resolver.Dial that connects to the local DoH
|
||||
|
// forwarder over UDP or TCP (whichever the resolver asked for).
|
||||
|
func dohForwarderDial(r *DohResolver) dialFunc { |
||||
|
return func(ctx context.Context, network, _ string) (net.Conn, error) { |
||||
|
fwd, err := sharedDohForwarder(r) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
var d net.Dialer |
||||
|
switch network { |
||||
|
case "tcp", "tcp4", "tcp6": |
||||
|
return d.DialContext(ctx, "tcp", fwd.tcpAddr) |
||||
|
default: |
||||
|
return d.DialContext(ctx, "udp", fwd.udpAddr) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
const ( |
||||
|
DNSModeUDP = "udp" |
||||
|
DNSModeDoH = "doh" |
||||
|
DNSModeAuto = "auto" |
||||
|
) |
||||
|
|
||||
|
var udpDNSServers = []string{ |
||||
|
"77.88.8.8:53", "77.88.8.1:53", |
||||
|
"8.8.8.8:53", "8.8.4.4:53", |
||||
|
"1.1.1.1:53", "1.0.0.1:53", |
||||
|
} |
||||
|
|
||||
|
type dialFunc = func(context.Context, string, string) (net.Conn, error) |
||||
|
|
||||
|
// buildDialer returns a net.Dialer whose internal Go resolver uses the
|
||||
|
// chosen DNS transport. In "auto" mode the first total-failure of UDP/53
|
||||
|
// sticks the process onto DoH for the rest of its lifetime.
|
||||
|
func buildDialer(mode string, r *DohResolver) net.Dialer { |
||||
|
switch mode { |
||||
|
case DNSModeUDP: |
||||
|
return newAppDialer(udpDNSDial) |
||||
|
case DNSModeDoH: |
||||
|
return newAppDialer(dohForwarderDial(r)) |
||||
|
case DNSModeAuto: |
||||
|
return newAppDialer(autoDial(r)) |
||||
|
default: |
||||
|
log.Panicf("unknown DNS mode %q", mode) |
||||
|
return net.Dialer{} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// newAppDialer wraps a Resolver.Dial with the timeouts used everywhere in
|
||||
|
// the app for outbound TCP/HTTP connections.
|
||||
|
func newAppDialer(dial dialFunc) net.Dialer { |
||||
|
return net.Dialer{ |
||||
|
Timeout: appDialerTimeout, |
||||
|
KeepAlive: appDialerKeepAlive, |
||||
|
Resolver: &net.Resolver{PreferGo: true, Dial: dial}, |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// udpDNSDial picks the first reachable UDP/53 resolver from udpDNSServers.
|
||||
|
func udpDNSDial(ctx context.Context, _ string, _ string) (net.Conn, error) { |
||||
|
var ( |
||||
|
d net.Dialer |
||||
|
lastErr error |
||||
|
) |
||||
|
for _, s := range udpDNSServers { |
||||
|
conn, err := d.DialContext(ctx, "udp", s) |
||||
|
if err == nil { |
||||
|
return conn, nil |
||||
|
} |
||||
|
lastErr = err |
||||
|
} |
||||
|
if lastErr == nil { |
||||
|
lastErr = errors.New("no UDP DNS servers available") |
||||
|
} |
||||
|
return nil, lastErr |
||||
|
} |
||||
|
|
||||
|
// autoDial returns a Dial that probes UDP/53 once with a real DNS round-trip;
|
||||
|
// if the probe fails it latches onto DoH for the rest of the process. Built
|
||||
|
// for Android, where the network can flip between Wi-Fi (UDP/53 works) and
|
||||
|
// mobile (UDP/53 blocked).
|
||||
|
//
|
||||
|
// A simple dial-timeout doesn't work for UDP because UDP "dial" is
|
||||
|
// connectionless and always succeeds instantly. The only way to know whether
|
||||
|
// UDP/53 actually works is to send a real query and wait for a response.
|
||||
|
func autoDial(r *DohResolver) dialFunc { |
||||
|
var ( |
||||
|
probed sync.Once |
||||
|
useDoH atomic.Bool |
||||
|
doh = dohForwarderDial(r) |
||||
|
) |
||||
|
return func(ctx context.Context, network, addr string) (net.Conn, error) { |
||||
|
probed.Do(func() { |
||||
|
if udpProbe(autoUDPBudget) { |
||||
|
log.Printf("[DNS] UDP/53 probe OK, using UDP") |
||||
|
} else { |
||||
|
log.Printf("[DNS] UDP/53 unreachable; sticky-switching to DoH") |
||||
|
useDoH.Store(true) |
||||
|
} |
||||
|
}) |
||||
|
if useDoH.Load() { |
||||
|
return doh(ctx, network, addr) |
||||
|
} |
||||
|
return udpDNSDial(ctx, network, addr) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// udpProbe sends a real DNS A query for a well-known domain via UDP and
|
||||
|
// checks whether any response arrives within the deadline. We try the first
|
||||
|
// two servers from udpDNSServers under a shared deadline — if neither
|
||||
|
// responds, UDP/53 is blocked.
|
||||
|
func udpProbe(timeout time.Duration) bool { |
||||
|
m := new(dns.Msg) |
||||
|
m.SetQuestion("dns.google.", dns.TypeA) |
||||
|
m.RecursionDesired = true |
||||
|
wire, err := m.Pack() |
||||
|
if err != nil { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
deadline := time.Now().Add(timeout) |
||||
|
buf := make([]byte, 512) |
||||
|
limit := min(len(udpDNSServers), 2) |
||||
|
for _, server := range udpDNSServers[:limit] { |
||||
|
remaining := time.Until(deadline) |
||||
|
if remaining <= 0 { |
||||
|
break |
||||
|
} |
||||
|
conn, err := net.DialTimeout("udp", server, remaining) |
||||
|
if err != nil { |
||||
|
continue |
||||
|
} |
||||
|
_ = conn.SetDeadline(deadline) |
||||
|
_, _ = conn.Write(wire) |
||||
|
n, err := conn.Read(buf) |
||||
|
_ = conn.Close() |
||||
|
if err == nil && n > 12 { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
return false |
||||
|
} |
||||
@ -0,0 +1,197 @@ |
|||||
|
package main |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"io" |
||||
|
"net" |
||||
|
"net/http" |
||||
|
"net/http/httptest" |
||||
|
"sync/atomic" |
||||
|
"testing" |
||||
|
"time" |
||||
|
|
||||
|
"github.com/miekg/dns" |
||||
|
) |
||||
|
|
||||
|
// dohAnswer builds a wire-format DNS reply for a single question with one
|
||||
|
// answer of the matching type (A or AAAA). TTL is returned as-is.
|
||||
|
func dohAnswer(t *testing.T, query []byte, ip net.IP, ttl uint32) []byte { |
||||
|
t.Helper() |
||||
|
req := new(dns.Msg) |
||||
|
if err := req.Unpack(query); err != nil { |
||||
|
t.Fatalf("unpack query: %v", err) |
||||
|
} |
||||
|
reply := new(dns.Msg) |
||||
|
reply.SetReply(req) |
||||
|
if len(req.Question) != 1 { |
||||
|
t.Fatalf("expected 1 question, got %d", len(req.Question)) |
||||
|
} |
||||
|
q := req.Question[0] |
||||
|
switch q.Qtype { |
||||
|
case dns.TypeA: |
||||
|
if v4 := ip.To4(); v4 != nil { |
||||
|
reply.Answer = append(reply.Answer, &dns.A{ |
||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, |
||||
|
A: v4, |
||||
|
}) |
||||
|
} |
||||
|
case dns.TypeAAAA: |
||||
|
if ip.To4() == nil { |
||||
|
reply.Answer = append(reply.Answer, &dns.AAAA{ |
||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, |
||||
|
AAAA: ip, |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
out, err := reply.Pack() |
||||
|
if err != nil { |
||||
|
t.Fatalf("pack reply: %v", err) |
||||
|
} |
||||
|
return out |
||||
|
} |
||||
|
|
||||
|
func readWire(t *testing.T, r io.Reader) []byte { |
||||
|
t.Helper() |
||||
|
b, err := io.ReadAll(r) |
||||
|
if err != nil { |
||||
|
t.Fatalf("read body: %v", err) |
||||
|
} |
||||
|
return b |
||||
|
} |
||||
|
|
||||
|
func TestDohResolver_LookupIPAddr_Success(t *testing.T) { |
||||
|
var hits atomic.Int32 |
||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
hits.Add(1) |
||||
|
if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" { |
||||
|
t.Errorf("wrong Content-Type: %q", ct) |
||||
|
} |
||||
|
body := readWire(t, r.Body) |
||||
|
w.Header().Set("Content-Type", "application/dns-message") |
||||
|
w.WriteHeader(http.StatusOK) |
||||
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("93.184.216.34"), 300)) |
||||
|
})) |
||||
|
defer srv.Close() |
||||
|
|
||||
|
r := newDohResolverWithClient( |
||||
|
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, |
||||
|
srv.Client(), |
||||
|
) |
||||
|
|
||||
|
ips, err := r.LookupIPAddr(context.Background(), "example.com") |
||||
|
if err != nil { |
||||
|
t.Fatalf("lookup: %v", err) |
||||
|
} |
||||
|
if len(ips) == 0 { |
||||
|
t.Fatalf("no ips returned") |
||||
|
} |
||||
|
if ips[0].String() != "93.184.216.34" { |
||||
|
t.Fatalf("unexpected ip %s", ips[0]) |
||||
|
} |
||||
|
// Two concurrent queries fire (A + AAAA), so we expect 2 hits.
|
||||
|
if got := hits.Load(); got != 2 { |
||||
|
t.Fatalf("expected 2 HTTP hits, got %d", got) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestDohResolver_Fallback(t *testing.T) { |
||||
|
var firstHits, secondHits atomic.Int32 |
||||
|
first := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
firstHits.Add(1) |
||||
|
w.WriteHeader(http.StatusInternalServerError) |
||||
|
})) |
||||
|
defer first.Close() |
||||
|
second := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
secondHits.Add(1) |
||||
|
body := readWire(t, r.Body) |
||||
|
w.Header().Set("Content-Type", "application/dns-message") |
||||
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("1.2.3.4"), 300)) |
||||
|
})) |
||||
|
defer second.Close() |
||||
|
|
||||
|
r := newDohResolverWithClient( |
||||
|
[]DohEndpoint{ |
||||
|
{URL: first.URL, Hostname: "first", BootstrapIPs: []string{"127.0.0.1"}}, |
||||
|
{URL: second.URL, Hostname: "second", BootstrapIPs: []string{"127.0.0.1"}}, |
||||
|
}, |
||||
|
first.Client(), |
||||
|
) |
||||
|
ips, err := r.LookupIPAddr(context.Background(), "example.com") |
||||
|
if err != nil { |
||||
|
t.Fatalf("lookup: %v", err) |
||||
|
} |
||||
|
if len(ips) != 1 || ips[0].String() != "1.2.3.4" { |
||||
|
t.Fatalf("unexpected ips: %v", ips) |
||||
|
} |
||||
|
if firstHits.Load() == 0 || secondHits.Load() == 0 { |
||||
|
t.Fatalf("fallback did not probe both endpoints: first=%d second=%d", firstHits.Load(), secondHits.Load()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestDohResolver_Cache(t *testing.T) { |
||||
|
var hits atomic.Int32 |
||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
hits.Add(1) |
||||
|
body := readWire(t, r.Body) |
||||
|
w.Header().Set("Content-Type", "application/dns-message") |
||||
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("5.6.7.8"), 300)) |
||||
|
})) |
||||
|
defer srv.Close() |
||||
|
|
||||
|
r := newDohResolverWithClient( |
||||
|
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, |
||||
|
srv.Client(), |
||||
|
) |
||||
|
if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil { |
||||
|
t.Fatalf("first lookup: %v", err) |
||||
|
} |
||||
|
firstHits := hits.Load() |
||||
|
if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil { |
||||
|
t.Fatalf("second lookup: %v", err) |
||||
|
} |
||||
|
if hits.Load() != firstHits { |
||||
|
t.Fatalf("cache miss: expected %d HTTP hits, got %d", firstHits, hits.Load()) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestAutoDial_StickyAfterUDPFailure(t *testing.T) { |
||||
|
// DoH backend: always responds with a valid wire-format reply.
|
||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
|
body := readWire(t, r.Body) |
||||
|
w.Header().Set("Content-Type", "application/dns-message") |
||||
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("9.9.9.9"), 300)) |
||||
|
})) |
||||
|
defer srv.Close() |
||||
|
|
||||
|
resolver := newDohResolverWithClient( |
||||
|
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, |
||||
|
srv.Client(), |
||||
|
) |
||||
|
|
||||
|
dial := autoDial(resolver) |
||||
|
|
||||
|
// Poison udpDNSServers so that udpProbe (real DNS round-trip) fails
|
||||
|
// immediately — net.DialTimeout rejects the malformed address.
|
||||
|
old := udpDNSServers |
||||
|
udpDNSServers = []string{"not-a-valid-host-port"} |
||||
|
defer func() { udpDNSServers = old }() |
||||
|
|
||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
||||
|
defer cancel() |
||||
|
|
||||
|
conn1, err := dial(ctx, "udp", "unused") |
||||
|
if err != nil { |
||||
|
t.Fatalf("first dial: %v", err) |
||||
|
} |
||||
|
_ = conn1.Close() |
||||
|
|
||||
|
// Second call must skip UDP entirely. We assert this by poisoning
|
||||
|
// udpDNSServers with a value that would fail parsing — if the dialer
|
||||
|
// touches UDP again the call errors loudly.
|
||||
|
udpDNSServers = []string{"still-not-a-valid-host-port"} |
||||
|
conn2, err := dial(ctx, "udp", "unused") |
||||
|
if err != nil { |
||||
|
t.Fatalf("second dial: %v", err) |
||||
|
} |
||||
|
_ = conn2.Close() |
||||
|
} |
||||
File diff suppressed because it is too large
Loading…
Reference in new issue