You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
197 lines
5.7 KiB
197 lines
5.7 KiB
package main
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// dohAnswer builds a wire-format DNS reply for a single question with one
|
|
// answer of the matching type (A or AAAA). TTL is returned as-is.
|
|
func dohAnswer(t *testing.T, query []byte, ip net.IP, ttl uint32) []byte {
|
|
t.Helper()
|
|
req := new(dns.Msg)
|
|
if err := req.Unpack(query); err != nil {
|
|
t.Fatalf("unpack query: %v", err)
|
|
}
|
|
reply := new(dns.Msg)
|
|
reply.SetReply(req)
|
|
if len(req.Question) != 1 {
|
|
t.Fatalf("expected 1 question, got %d", len(req.Question))
|
|
}
|
|
q := req.Question[0]
|
|
switch q.Qtype {
|
|
case dns.TypeA:
|
|
if v4 := ip.To4(); v4 != nil {
|
|
reply.Answer = append(reply.Answer, &dns.A{
|
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
|
|
A: v4,
|
|
})
|
|
}
|
|
case dns.TypeAAAA:
|
|
if ip.To4() == nil {
|
|
reply.Answer = append(reply.Answer, &dns.AAAA{
|
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl},
|
|
AAAA: ip,
|
|
})
|
|
}
|
|
}
|
|
out, err := reply.Pack()
|
|
if err != nil {
|
|
t.Fatalf("pack reply: %v", err)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// readWire reads r to EOF and returns everything read. In these tests r is a
// DoH request body, i.e. a wire-format DNS query.
//
// readWire runs inside httptest handler goroutines, where t.Fatal/FailNow
// must not be used: they only work on the goroutine running the test. A read
// error is therefore reported with t.Errorf, and nil is returned.
func readWire(t testing.TB, r io.Reader) []byte {
	t.Helper()
	b, err := io.ReadAll(r)
	if err != nil {
		t.Errorf("read body: %v", err)
		return nil
	}
	return b
}
|
|
|
|
func TestDohResolver_LookupIPAddr_Success(t *testing.T) {
|
|
var hits atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
hits.Add(1)
|
|
if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" {
|
|
t.Errorf("wrong Content-Type: %q", ct)
|
|
}
|
|
body := readWire(t, r.Body)
|
|
w.Header().Set("Content-Type", "application/dns-message")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("93.184.216.34"), 300))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
r := newDohResolverWithClient(
|
|
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}},
|
|
srv.Client(),
|
|
)
|
|
|
|
ips, err := r.LookupIPAddr(context.Background(), "example.com")
|
|
if err != nil {
|
|
t.Fatalf("lookup: %v", err)
|
|
}
|
|
if len(ips) == 0 {
|
|
t.Fatalf("no ips returned")
|
|
}
|
|
if ips[0].String() != "93.184.216.34" {
|
|
t.Fatalf("unexpected ip %s", ips[0])
|
|
}
|
|
// Two concurrent queries fire (A + AAAA), so we expect 2 hits.
|
|
if got := hits.Load(); got != 2 {
|
|
t.Fatalf("expected 2 HTTP hits, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestDohResolver_Fallback(t *testing.T) {
|
|
var firstHits, secondHits atomic.Int32
|
|
first := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
firstHits.Add(1)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
defer first.Close()
|
|
second := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
secondHits.Add(1)
|
|
body := readWire(t, r.Body)
|
|
w.Header().Set("Content-Type", "application/dns-message")
|
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("1.2.3.4"), 300))
|
|
}))
|
|
defer second.Close()
|
|
|
|
r := newDohResolverWithClient(
|
|
[]DohEndpoint{
|
|
{URL: first.URL, Hostname: "first", BootstrapIPs: []string{"127.0.0.1"}},
|
|
{URL: second.URL, Hostname: "second", BootstrapIPs: []string{"127.0.0.1"}},
|
|
},
|
|
first.Client(),
|
|
)
|
|
ips, err := r.LookupIPAddr(context.Background(), "example.com")
|
|
if err != nil {
|
|
t.Fatalf("lookup: %v", err)
|
|
}
|
|
if len(ips) != 1 || ips[0].String() != "1.2.3.4" {
|
|
t.Fatalf("unexpected ips: %v", ips)
|
|
}
|
|
if firstHits.Load() == 0 || secondHits.Load() == 0 {
|
|
t.Fatalf("fallback did not probe both endpoints: first=%d second=%d", firstHits.Load(), secondHits.Load())
|
|
}
|
|
}
|
|
|
|
func TestDohResolver_Cache(t *testing.T) {
|
|
var hits atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
hits.Add(1)
|
|
body := readWire(t, r.Body)
|
|
w.Header().Set("Content-Type", "application/dns-message")
|
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("5.6.7.8"), 300))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
r := newDohResolverWithClient(
|
|
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}},
|
|
srv.Client(),
|
|
)
|
|
if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil {
|
|
t.Fatalf("first lookup: %v", err)
|
|
}
|
|
firstHits := hits.Load()
|
|
if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil {
|
|
t.Fatalf("second lookup: %v", err)
|
|
}
|
|
if hits.Load() != firstHits {
|
|
t.Fatalf("cache miss: expected %d HTTP hits, got %d", firstHits, hits.Load())
|
|
}
|
|
}
|
|
|
|
func TestAutoDial_StickyAfterUDPFailure(t *testing.T) {
|
|
// DoH backend: always responds with a valid wire-format reply.
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body := readWire(t, r.Body)
|
|
w.Header().Set("Content-Type", "application/dns-message")
|
|
_, _ = w.Write(dohAnswer(t, body, net.ParseIP("9.9.9.9"), 300))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
resolver := newDohResolverWithClient(
|
|
[]DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}},
|
|
srv.Client(),
|
|
)
|
|
|
|
dial := autoDial(resolver)
|
|
|
|
// Poison udpDNSServers so that udpProbe (real DNS round-trip) fails
|
|
// immediately — net.DialTimeout rejects the malformed address.
|
|
old := udpDNSServers
|
|
udpDNSServers = []string{"not-a-valid-host-port"}
|
|
defer func() { udpDNSServers = old }()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
conn1, err := dial(ctx, "udp", "unused")
|
|
if err != nil {
|
|
t.Fatalf("first dial: %v", err)
|
|
}
|
|
_ = conn1.Close()
|
|
|
|
// Second call must skip UDP entirely. We assert this by poisoning
|
|
// udpDNSServers with a value that would fail parsing — if the dialer
|
|
// touches UDP again the call errors loudly.
|
|
udpDNSServers = []string{"still-not-a-valid-host-port"}
|
|
conn2, err := dial(ctx, "udp", "unused")
|
|
if err != nil {
|
|
t.Fatalf("second dial: %v", err)
|
|
}
|
|
_ = conn2.Close()
|
|
}
|
|
|