diff --git a/client/doh.go b/client/doh.go new file mode 100644 index 0000000..76744e6 --- /dev/null +++ b/client/doh.go @@ -0,0 +1,601 @@ +// DNS-over-HTTPS resolver for mobile networks where UDP/53 is blocked or +// spoofed. + +package main + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + + // Embedded Mozilla CA roots for CGO_ENABLED=0 builds (Android). + _ "golang.org/x/crypto/x509roots/fallback" +) + +const ( + dohQueryTimeout = 6 * time.Second + dohCacheMinTTL = 10 * time.Second + dohCacheMaxTTL = 1 * time.Hour + dohMaxResponseBytes = 64 * 1024 + dohContentType = "application/dns-message" + + dohDialerTimeout = 5 * time.Second + dohDialerKeepAlive = 30 * time.Second + appDialerTimeout = 20 * time.Second + appDialerKeepAlive = 30 * time.Second + + forwarderUDPBufSize = 4096 + forwarderTCPReadDL = 30 * time.Second + forwarderTCPWriteDL = 10 * time.Second + autoUDPBudget = 1500 * time.Millisecond +) + +// DohEndpoint describes a single DNS-over-HTTPS server together with the IPs +// we bootstrap to — so that resolving the endpoint hostname does not itself +// require DNS. +type DohEndpoint struct { + URL string + Hostname string + BootstrapIPs []string +} + +// Yandex is tried first because it tends to stay reachable on RU mobile +// operators even when international resolvers get blocked; Google and +// Cloudflare follow as fallbacks. +var defaultDohEndpoints = []DohEndpoint{ + {"https://common.dot.dns.yandex.net/dns-query", "common.dot.dns.yandex.net", []string{"77.88.8.8", "77.88.8.1"}}, + {"https://secure.dot.dns.yandex.net/dns-query", "secure.dot.dns.yandex.net", []string{"77.88.8.88", "77.88.8.2"}}, + {"https://family.dot.dns.yandex.net/dns-query", "family.dot.dns.yandex.net", []string{"77.88.8.7", "77.88.8.3"}}, + {"https://dns.google/dns-query", "dns.google", []string{"8.8.8.8", "8.8.4.4"}}, + {"https://cloudflare-dns.com/dns-query", "cloudflare-dns.com", []string{"1.1.1.1", "1.0.0.1"}}, +} + +// DohResolver resolves hostnames to IPs via DNS-over-HTTPS (RFC 8484). +type DohResolver struct { + endpoints []DohEndpoint + client *http.Client + cache *dohCache +} + +// NewDohResolver constructs a resolver using defaultDohEndpoints if endpoints +// is nil. Endpoint hostnames are dialed by IP using BootstrapIPs, so the DoH +// transport never depends on the system resolver. +func NewDohResolver(endpoints []DohEndpoint) *DohResolver { + if len(endpoints) == 0 { + endpoints = defaultDohEndpoints + } + return &DohResolver{ + endpoints: endpoints, + client: &http.Client{Timeout: dohQueryTimeout, Transport: newBootstrapTransport(endpoints)}, + cache: newDohCache(), + } +} + +// newDohResolverWithClient is a test hook that skips the bootstrap transport. +func newDohResolverWithClient(endpoints []DohEndpoint, client *http.Client) *DohResolver { + return &DohResolver{endpoints: endpoints, client: client, cache: newDohCache()} +} + +// newBootstrapTransport returns an http.Transport whose DialContext only +// knows how to reach the configured DoH endpoint hostnames, by mapping each +// to its BootstrapIPs. +func newBootstrapTransport(endpoints []DohEndpoint) *http.Transport { + bootstrap := make(map[string][]string, len(endpoints)) + for _, ep := range endpoints { + bootstrap[ep.Hostname] = ep.BootstrapIPs + } + dialer := &net.Dialer{Timeout: dohDialerTimeout, KeepAlive: dohDialerKeepAlive} + + return &http.Transport{ + MaxIdleConns: 8, + MaxIdleConnsPerHost: 2, + IdleConnTimeout: 90 * time.Second, + ForceAttemptHTTP2: true, + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ips, ok := bootstrap[host] + if !ok { + return nil, fmt.Errorf("doh: no bootstrap IPs for %q", host) + } + var lastErr error + for _, ip := range ips { + conn, derr := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) + if derr == nil { + return conn, nil + } + lastErr = derr + } + return nil, lastErr + }, + } +} + +// LookupIPAddr resolves host to a combined list of A+AAAA IPs (IPv4 first). +// Cached results bypass the network entirely. +func (r *DohResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IP, error) { + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + if ips, ok := r.cache.get(host); ok { + return ips, nil + } + + type res struct { + ips []net.IP + ttl time.Duration + err error + } + results := make(chan res, 2) + for _, qt := range [...]uint16{dns.TypeA, dns.TypeAAAA} { + go func(qtype uint16) { + ips, ttl, err := r.queryIPs(ctx, host, qtype) + results <- res{ips, ttl, err} + }(qt) + } + + var ( + all []net.IP + lastErr error + minTTL = dohCacheMaxTTL + ) + for range 2 { + rr := <-results + if rr.err != nil { + lastErr = rr.err + continue + } + all = append(all, rr.ips...) + if rr.ttl > 0 && rr.ttl < minTTL { + minTTL = rr.ttl + } + } + if len(all) == 0 { + if lastErr == nil { + lastErr = fmt.Errorf("doh: no records for %s", host) + } + return nil, lastErr + } + + // IPv4 before IPv6 — better compat with mobile IPv4-only CGNAT. + sort.SliceStable(all, func(i, j int) bool { + return (all[i].To4() != nil) && (all[j].To4() == nil) + }) + + if minTTL < dohCacheMinTTL { + minTTL = dohCacheMinTTL + } + r.cache.set(host, all, minTTL) + return all, nil +} + +// queryIPs issues one DoH query for qtype, walking endpoints until one +// succeeds, and parses the wire reply into IPs + min TTL. +func (r *DohResolver) queryIPs(ctx context.Context, host string, qtype uint16) ([]net.IP, time.Duration, error) { + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(host), qtype) + m.Id = 0 // RFC 8484 §4.1 — zero ID is cache-friendly on shared caches. + m.RecursionDesired = true + wire, err := m.Pack() + if err != nil { + return nil, 0, fmt.Errorf("doh: pack query: %w", err) + } + + body, ep, err := r.forwardRaw(ctx, wire) + if err != nil { + return nil, 0, err + } + ips, ttl, err := parseAnswer(body) + if err != nil { + return nil, 0, fmt.Errorf("doh: parse %s: %w", ep.Hostname, err) + } + log.Printf("[DoH] %s %s via %s → %d IPs (ttl %s)", host, dns.TypeToString[qtype], ep.Hostname, len(ips), ttl) + return ips, ttl, nil +} + +// parseAnswer decodes a DNS wire reply into A/AAAA records and the minimum TTL. +func parseAnswer(body []byte) ([]net.IP, time.Duration, error) { + reply := new(dns.Msg) + if err := reply.Unpack(body); err != nil { + return nil, 0, fmt.Errorf("unpack: %w", err) + } + if reply.Rcode != dns.RcodeSuccess { + return nil, 0, fmt.Errorf("rcode %s", dns.RcodeToString[reply.Rcode]) + } + var ( + ips []net.IP + minTTL uint32 + ) + updateTTL := func(ttl uint32) { + if minTTL == 0 || ttl < minTTL { + minTTL = ttl + } + } + for _, ans := range reply.Answer { + switch a := ans.(type) { + case *dns.A: + ips = append(ips, a.A) + updateTTL(a.Hdr.Ttl) + case *dns.AAAA: + ips = append(ips, a.AAAA) + updateTTL(a.Hdr.Ttl) + } + } + return ips, time.Duration(minTTL) * time.Second, nil +} + +// forwardRaw POSTs an opaque DNS-wire query to the configured DoH endpoints +// in order and returns the first successful raw response together with the +// endpoint that produced it. No parsing — useful for the local forwarder +// which needs to pass through whatever the upstream resolver answers +// (RESINFO/HTTPS/SVCB/EDNS options/…). +func (r *DohResolver) forwardRaw(ctx context.Context, query []byte) ([]byte, DohEndpoint, error) { + if len(r.endpoints) == 0 { + return nil, DohEndpoint{}, errors.New("doh: no endpoints configured") + } + var lastErr error + for _, ep := range r.endpoints { + body, err := r.postWire(ctx, ep, query) + if err != nil { + log.Printf("[DoH] %s: %v", ep.Hostname, err) + lastErr = err + continue + } + return body, ep, nil + } + return nil, DohEndpoint{}, lastErr +} + +// postWire performs a single application/dns-message POST to one endpoint. +func (r *DohResolver) postWire(ctx context.Context, ep DohEndpoint, query []byte) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "POST", ep.URL, bytes.NewReader(query)) + if err != nil { + return nil, fmt.Errorf("build request: %w", err) + } + req.Header.Set("Content-Type", dohContentType) + req.Header.Set("Accept", dohContentType) + + resp, err := r.client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + _, _ = io.Copy(io.Discard, resp.Body) + return nil, fmt.Errorf("HTTP %d", resp.StatusCode) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, dohMaxResponseBytes)) + if err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + return body, nil +} + +type dohCacheEntry struct { + ips []net.IP + expiry time.Time +} + +type dohCache struct { + mu sync.RWMutex + m map[string]dohCacheEntry +} + +func newDohCache() *dohCache { + return &dohCache{m: make(map[string]dohCacheEntry)} +} + +func (c *dohCache) get(host string) ([]net.IP, bool) { + c.mu.RLock() + e, ok := c.m[host] + c.mu.RUnlock() + if !ok || time.Now().After(e.expiry) { + return nil, false + } + out := make([]net.IP, len(e.ips)) + copy(out, e.ips) + return out, true +} + +func (c *dohCache) set(host string, ips []net.IP, ttl time.Duration) { + if ttl <= 0 { + return + } + if ttl > dohCacheMaxTTL { + ttl = dohCacheMaxTTL + } + cp := make([]net.IP, len(ips)) + copy(cp, ips) + c.mu.Lock() + c.m[host] = dohCacheEntry{ips: cp, expiry: time.Now().Add(ttl)} + c.mu.Unlock() +} + +// Go's net.Resolver dials this stub like a regular nameserver, which avoids +// the many edge cases of a fake-net.Conn approach (RESINFO probes, EDNS +// handshakes, truncation, …). Whatever it reads on UDP/TCP is sent verbatim +// to a DoH endpoint and the wire response is sent back to the client. + +type dohForwarder struct { + udpAddr string + tcpAddr string +} + +var ( + dohForwarderOnce sync.Once + dohForwarderInst *dohForwarder + dohForwarderErr error +) + +// sharedDohForwarder lazily starts a process-wide forwarder bound to the +// supplied resolver. The first caller wins; subsequent callers reuse the +// same forwarder regardless of what they pass in. +func sharedDohForwarder(r *DohResolver) (*dohForwarder, error) { + dohForwarderOnce.Do(func() { + dohForwarderInst, dohForwarderErr = startDohForwarder(r) + }) + return dohForwarderInst, dohForwarderErr +} + +func startDohForwarder(r *DohResolver) (_ *dohForwarder, err error) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + if err != nil { + return nil, fmt.Errorf("doh forwarder: listen UDP: %w", err) + } + defer func() { + if err != nil { + _ = udpConn.Close() + } + }() + tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}) + if err != nil { + return nil, fmt.Errorf("doh forwarder: listen TCP: %w", err) + } + defer func() { + if err != nil { + _ = tcpLn.Close() + } + }() + + fwd := &dohForwarder{ + udpAddr: udpConn.LocalAddr().String(), + tcpAddr: tcpLn.Addr().String(), + } + log.Printf("[DoH] forwarder listening udp=%s tcp=%s", fwd.udpAddr, fwd.tcpAddr) + + go fwd.serveUDP(udpConn, r) + go fwd.serveTCP(tcpLn, r) + return fwd, nil +} + +func (f *dohForwarder) serveUDP(conn *net.UDPConn, r *DohResolver) { + defer func() { _ = conn.Close() }() + buf := make([]byte, forwarderUDPBufSize) + for { + n, client, err := conn.ReadFromUDP(buf) + if err != nil { + log.Printf("[DoH] udp read: %v", err) + return + } + query := append([]byte(nil), buf[:n]...) + go func(q []byte, c *net.UDPAddr) { + ctx, cancel := context.WithTimeout(context.Background(), dohQueryTimeout) + defer cancel() + resp, _, err := r.forwardRaw(ctx, q) + if err != nil { + log.Printf("[DoH] udp forward failed: %v", err) + return + } + if _, err := conn.WriteToUDP(resp, c); err != nil { + log.Printf("[DoH] udp write: %v", err) + } + }(query, client) + } +} + +func (f *dohForwarder) serveTCP(ln *net.TCPListener, r *DohResolver) { + defer func() { _ = ln.Close() }() + for { + conn, err := ln.Accept() + if err != nil { + log.Printf("[DoH] tcp accept: %v", err) + return + } + go handleDohForwarderTCP(conn, r) + } +} + +func handleDohForwarderTCP(conn net.Conn, r *DohResolver) { + defer func() { _ = conn.Close() }() + for { + _ = conn.SetReadDeadline(time.Now().Add(forwarderTCPReadDL)) + var lenBuf [2]byte + if _, err := io.ReadFull(conn, lenBuf[:]); err != nil { + return + } + qlen := int(lenBuf[0])<<8 | int(lenBuf[1]) + if qlen == 0 || qlen > forwarderUDPBufSize { + return + } + query := make([]byte, qlen) + if _, err := io.ReadFull(conn, query); err != nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), dohQueryTimeout) + resp, _, err := r.forwardRaw(ctx, query) + cancel() + if err != nil { + log.Printf("[DoH] tcp forward failed: %v", err) + return + } + out := make([]byte, 2+len(resp)) + out[0] = byte(len(resp) >> 8) + out[1] = byte(len(resp)) + copy(out[2:], resp) + _ = conn.SetWriteDeadline(time.Now().Add(forwarderTCPWriteDL)) + if _, err := conn.Write(out); err != nil { + return + } + } +} + +// dohForwarderDial returns a Resolver.Dial that connects to the local DoH +// forwarder over UDP or TCP (whichever the resolver asked for). +func dohForwarderDial(r *DohResolver) dialFunc { + return func(ctx context.Context, network, _ string) (net.Conn, error) { + fwd, err := sharedDohForwarder(r) + if err != nil { + return nil, err + } + var d net.Dialer + switch network { + case "tcp", "tcp4", "tcp6": + return d.DialContext(ctx, "tcp", fwd.tcpAddr) + default: + return d.DialContext(ctx, "udp", fwd.udpAddr) + } + } +} + +const ( + DNSModeUDP = "udp" + DNSModeDoH = "doh" + DNSModeAuto = "auto" +) + +var udpDNSServers = []string{ + "77.88.8.8:53", "77.88.8.1:53", + "8.8.8.8:53", "8.8.4.4:53", + "1.1.1.1:53", "1.0.0.1:53", +} + +type dialFunc = func(context.Context, string, string) (net.Conn, error) + +// buildDialer returns a net.Dialer whose internal Go resolver uses the +// chosen DNS transport. In "auto" mode the first total-failure of UDP/53 +// sticks the process onto DoH for the rest of its lifetime. +func buildDialer(mode string, r *DohResolver) net.Dialer { + switch mode { + case DNSModeUDP: + return newAppDialer(udpDNSDial) + case DNSModeDoH: + return newAppDialer(dohForwarderDial(r)) + case DNSModeAuto: + return newAppDialer(autoDial(r)) + default: + log.Panicf("unknown DNS mode %q", mode) + return net.Dialer{} + } +} + +// newAppDialer wraps a Resolver.Dial with the timeouts used everywhere in +// the app for outbound TCP/HTTP connections. +func newAppDialer(dial dialFunc) net.Dialer { + return net.Dialer{ + Timeout: appDialerTimeout, + KeepAlive: appDialerKeepAlive, + Resolver: &net.Resolver{PreferGo: true, Dial: dial}, + } +} + +// udpDNSDial picks the first reachable UDP/53 resolver from udpDNSServers. +func udpDNSDial(ctx context.Context, _ string, _ string) (net.Conn, error) { + var ( + d net.Dialer + lastErr error + ) + for _, s := range udpDNSServers { + conn, err := d.DialContext(ctx, "udp", s) + if err == nil { + return conn, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = errors.New("no UDP DNS servers available") + } + return nil, lastErr +} + +// autoDial returns a Dial that probes UDP/53 once with a real DNS round-trip; +// if the probe fails it latches onto DoH for the rest of the process. Built +// for Android, where the network can flip between Wi-Fi (UDP/53 works) and +// mobile (UDP/53 blocked). +// +// A simple dial-timeout doesn't work for UDP because UDP "dial" is +// connectionless and always succeeds instantly. The only way to know whether +// UDP/53 actually works is to send a real query and wait for a response. +func autoDial(r *DohResolver) dialFunc { + var ( + probed sync.Once + useDoH atomic.Bool + doh = dohForwarderDial(r) + ) + return func(ctx context.Context, network, addr string) (net.Conn, error) { + probed.Do(func() { + if udpProbe(autoUDPBudget) { + log.Printf("[DNS] UDP/53 probe OK, using UDP") + } else { + log.Printf("[DNS] UDP/53 unreachable; sticky-switching to DoH") + useDoH.Store(true) + } + }) + if useDoH.Load() { + return doh(ctx, network, addr) + } + return udpDNSDial(ctx, network, addr) + } +} + +// udpProbe sends a real DNS A query for a well-known domain via UDP and +// checks whether any response arrives within the deadline. We try the first +// two servers from udpDNSServers under a shared deadline — if neither +// responds, UDP/53 is blocked. +func udpProbe(timeout time.Duration) bool { + m := new(dns.Msg) + m.SetQuestion("dns.google.", dns.TypeA) + m.RecursionDesired = true + wire, err := m.Pack() + if err != nil { + return false + } + + deadline := time.Now().Add(timeout) + buf := make([]byte, 512) + limit := min(len(udpDNSServers), 2) + for _, server := range udpDNSServers[:limit] { + remaining := time.Until(deadline) + if remaining <= 0 { + break + } + conn, err := net.DialTimeout("udp", server, remaining) + if err != nil { + continue + } + _ = conn.SetDeadline(deadline) + _, _ = conn.Write(wire) + n, err := conn.Read(buf) + _ = conn.Close() + if err == nil && n > 12 { + return true + } + } + return false +} diff --git a/client/doh_test.go b/client/doh_test.go new file mode 100644 index 0000000..39d9088 --- /dev/null +++ b/client/doh_test.go @@ -0,0 +1,197 @@ +package main + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" +) + +// dohAnswer builds a wire-format DNS reply for a single question with one +// answer of the matching type (A or AAAA). TTL is returned as-is. +func dohAnswer(t *testing.T, query []byte, ip net.IP, ttl uint32) []byte { + t.Helper() + req := new(dns.Msg) + if err := req.Unpack(query); err != nil { + t.Fatalf("unpack query: %v", err) + } + reply := new(dns.Msg) + reply.SetReply(req) + if len(req.Question) != 1 { + t.Fatalf("expected 1 question, got %d", len(req.Question)) + } + q := req.Question[0] + switch q.Qtype { + case dns.TypeA: + if v4 := ip.To4(); v4 != nil { + reply.Answer = append(reply.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, + A: v4, + }) + } + case dns.TypeAAAA: + if ip.To4() == nil { + reply.Answer = append(reply.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, + AAAA: ip, + }) + } + } + out, err := reply.Pack() + if err != nil { + t.Fatalf("pack reply: %v", err) + } + return out +} + +func readWire(t *testing.T, r io.Reader) []byte { + t.Helper() + b, err := io.ReadAll(r) + if err != nil { + t.Fatalf("read body: %v", err) + } + return b +} + +func TestDohResolver_LookupIPAddr_Success(t *testing.T) { + var hits atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits.Add(1) + if ct := r.Header.Get("Content-Type"); ct != "application/dns-message" { + t.Errorf("wrong Content-Type: %q", ct) + } + body := readWire(t, r.Body) + w.Header().Set("Content-Type", "application/dns-message") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(dohAnswer(t, body, net.ParseIP("93.184.216.34"), 300)) + })) + defer srv.Close() + + r := newDohResolverWithClient( + []DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, + srv.Client(), + ) + + ips, err := r.LookupIPAddr(context.Background(), "example.com") + if err != nil { + t.Fatalf("lookup: %v", err) + } + if len(ips) == 0 { + t.Fatalf("no ips returned") + } + if ips[0].String() != "93.184.216.34" { + t.Fatalf("unexpected ip %s", ips[0]) + } + // Two concurrent queries fire (A + AAAA), so we expect 2 hits. + if got := hits.Load(); got != 2 { + t.Fatalf("expected 2 HTTP hits, got %d", got) + } +} + +func TestDohResolver_Fallback(t *testing.T) { + var firstHits, secondHits atomic.Int32 + first := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + firstHits.Add(1) + w.WriteHeader(http.StatusInternalServerError) + })) + defer first.Close() + second := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secondHits.Add(1) + body := readWire(t, r.Body) + w.Header().Set("Content-Type", "application/dns-message") + _, _ = w.Write(dohAnswer(t, body, net.ParseIP("1.2.3.4"), 300)) + })) + defer second.Close() + + r := newDohResolverWithClient( + []DohEndpoint{ + {URL: first.URL, Hostname: "first", BootstrapIPs: []string{"127.0.0.1"}}, + {URL: second.URL, Hostname: "second", BootstrapIPs: []string{"127.0.0.1"}}, + }, + first.Client(), + ) + ips, err := r.LookupIPAddr(context.Background(), "example.com") + if err != nil { + t.Fatalf("lookup: %v", err) + } + if len(ips) != 1 || ips[0].String() != "1.2.3.4" { + t.Fatalf("unexpected ips: %v", ips) + } + if firstHits.Load() == 0 || secondHits.Load() == 0 { + t.Fatalf("fallback did not probe both endpoints: first=%d second=%d", firstHits.Load(), secondHits.Load()) + } +} + +func TestDohResolver_Cache(t *testing.T) { + var hits atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits.Add(1) + body := readWire(t, r.Body) + w.Header().Set("Content-Type", "application/dns-message") + _, _ = w.Write(dohAnswer(t, body, net.ParseIP("5.6.7.8"), 300)) + })) + defer srv.Close() + + r := newDohResolverWithClient( + []DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, + srv.Client(), + ) + if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil { + t.Fatalf("first lookup: %v", err) + } + firstHits := hits.Load() + if _, err := r.LookupIPAddr(context.Background(), "example.com"); err != nil { + t.Fatalf("second lookup: %v", err) + } + if hits.Load() != firstHits { + t.Fatalf("cache miss: expected %d HTTP hits, got %d", firstHits, hits.Load()) + } +} + +func TestAutoDial_StickyAfterUDPFailure(t *testing.T) { + // DoH backend: always responds with a valid wire-format reply. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body := readWire(t, r.Body) + w.Header().Set("Content-Type", "application/dns-message") + _, _ = w.Write(dohAnswer(t, body, net.ParseIP("9.9.9.9"), 300)) + })) + defer srv.Close() + + resolver := newDohResolverWithClient( + []DohEndpoint{{URL: srv.URL, Hostname: "mock", BootstrapIPs: []string{"127.0.0.1"}}}, + srv.Client(), + ) + + dial := autoDial(resolver) + + // Poison udpDNSServers so that udpProbe (real DNS round-trip) fails + // immediately — net.DialTimeout rejects the malformed address. + old := udpDNSServers + udpDNSServers = []string{"not-a-valid-host-port"} + defer func() { udpDNSServers = old }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + conn1, err := dial(ctx, "udp", "unused") + if err != nil { + t.Fatalf("first dial: %v", err) + } + _ = conn1.Close() + + // Second call must skip UDP entirely. We assert this by poisoning + // udpDNSServers with a value that would fail parsing — if the dialer + // touches UDP again the call errors loudly. + udpDNSServers = []string{"still-not-a-valid-host-port"} + conn2, err := dial(ctx, "udp", "unused") + if err != nil { + t.Fatalf("second dial: %v", err) + } + _ = conn2.Close() +} diff --git a/client/main.go b/client/main.go index 73dfd17..59bd2e6 100644 --- a/client/main.go +++ b/client/main.go @@ -11,6 +11,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "flag" "fmt" "io" @@ -21,6 +22,7 @@ import ( neturl "net/url" "os" "os/signal" + "runtime" "strconv" "strings" "sync" @@ -32,7 +34,6 @@ import ( tlsclient "github.com/bogdanfinn/tls-client" "github.com/bogdanfinn/tls-client/profiles" - "github.com/bschaatsbergen/dnsdialer" "github.com/cacggghp/vk-turn-proxy/tcputil" "github.com/cbeuw/connutil" "github.com/google/uuid" @@ -63,10 +64,12 @@ var ( globalCaptchaLockout atomic.Int64 connectedStreams atomic.Int32 globalAppCancel context.CancelFunc - handshakeSem = make(chan struct{}, 3) + handshakeSem chan struct{} isDebug bool manualCaptcha bool autoCaptchaSliderPOC bool + allocsPerStream int + udpMode bool ) type captchaSolveMode int @@ -227,47 +230,35 @@ func applyBrowserProfileFhttp(req *fhttp.Request, profile Profile) { req.Header.Set("DNT", "1") } +// generateBrowserFp produces a stable fallback fingerprint when no SavedProfile +// is available. Stable (no time component) so the value matches between +// componentDone and check inside the same auto-solve attempt. func generateBrowserFp(profile Profile) string { - data := profile.UserAgent + profile.SecChUa + "1920x1080x24" + strconv.FormatInt(time.Now().UnixNano(), 10) + data := profile.UserAgent + profile.SecChUa + "1536x864x24" h := md5.Sum([]byte(data)) return hex.EncodeToString(h[:]) } -func generateFakeCursor() string { - startX := 600 + rand.Intn(400) - startY := 300 + rand.Intn(200) - startTime := time.Now().UnixMilli() - int64(rand.Intn(2000)+1000) - var points []string - for i := 0; i < 15+rand.Intn(10); i++ { - startX += rand.Intn(15) - 5 - startY += rand.Intn(15) + 2 - startTime += int64(rand.Intn(40) + 10) - points = append(points, fmt.Sprintf(`{"x":%d,"y":%d,"t":%d}`, startX, startY, startTime)) - } - return "[" + strings.Join(points, ",") + "]" +// dnsMode is set in main() from the -dns flag and consumed by appDialer(). +var dnsMode = DNSModeAuto + +// dohResolverSingleton is shared across all callers of appDialer(). +var ( + dohResolverOnce sync.Once + dohResolverInstance *DohResolver +) + +func sharedDohResolver() *DohResolver { + dohResolverOnce.Do(func() { + dohResolverInstance = NewDohResolver(nil) + }) + return dohResolverInstance } -func getCustomNetDialer() net.Dialer { - return net.Dialer{ - Timeout: 20 * time.Second, - KeepAlive: 30 * time.Second, - Resolver: &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - dnsServers := []string{"77.88.8.8:53", "77.88.8.1:53", "8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} - var lastErr error - for _, dns := range dnsServers { - conn, err := d.DialContext(ctx, "udp", dns) - if err == nil { - return conn, nil - } - lastErr = err - } - return nil, lastErr - }, - }, - } +// appDialer returns the net.Dialer used by tls-client and other HTTP callers. +// DNS transport is selected by the -dns flag (udp | doh | auto). +func appDialer() net.Dialer { + return buildDialer(dnsMode, sharedDohResolver()) } // endregion @@ -393,6 +384,16 @@ func solveVkCaptcha(ctx context.Context, captchaErr *VkCaptchaError, streamID in return "", fmt.Errorf("no redirect_uri for auto-solve") } + // Reuse the real-browser fingerprint captured during a prior manual solve. + // VK fingerprints (browser_fp, device, UA) together; keeping them consistent + // across runs helps the auto path stay out of the BOT bucket. + var savedProfile *SavedProfile + if sp, err := LoadProfileFromDisk(); err == nil { + log.Printf("[STREAM %d] [Captcha] Using saved real browser profile", streamID) + savedProfile = sp + profile = sp.Profile + } + bootstrap, err := fetchCaptchaBootstrap(ctx, captchaErr.RedirectURI, client, profile) if err != nil { return "", fmt.Errorf("failed to fetch captcha bootstrap: %w", err) @@ -400,7 +401,10 @@ func solveVkCaptcha(ctx context.Context, captchaErr *VkCaptchaError, streamID in log.Printf("[STREAM %d] [Captcha] PoW input: %s, difficulty: %d", streamID, bootstrap.PowInput, bootstrap.Difficulty) - hash := solvePoW(bootstrap.PowInput, bootstrap.Difficulty) + hash, err := solvePoW(bootstrap.PowInput, bootstrap.Difficulty) + if err != nil { + return "", fmt.Errorf("PoW: %w", err) + } log.Printf("[STREAM %d] [Captcha] PoW solved: hash=%s", streamID, hash) var successToken string @@ -413,9 +417,10 @@ func solveVkCaptcha(ctx context.Context, captchaErr *VkCaptchaError, streamID in client, profile, bootstrap.Settings, + savedProfile, ) } else { - successToken, err = callCaptchaNotRobot(ctx, captchaErr.SessionToken, hash, streamID, client, profile) + successToken, err = callCaptchaNotRobot(ctx, captchaErr.SessionToken, hash, streamID, client, profile, savedProfile) } if err != nil { return "", fmt.Errorf("captchaNotRobot API failed: %w", err) @@ -459,20 +464,50 @@ func fetchCaptchaBootstrap(ctx context.Context, redirectURI string, client tlscl return parseCaptchaBootstrapHTML(string(body)) } -func solvePoW(powInput string, difficulty int) string { +func solvePoW(powInput string, difficulty int) (string, error) { target := strings.Repeat("0", difficulty) - for nonce := 1; nonce <= 10000000; nonce++ { - data := powInput + strconv.Itoa(nonce) - hash := sha256.Sum256([]byte(data)) - hexHash := hex.EncodeToString(hash[:]) - if strings.HasPrefix(hexHash, target) { - return hexHash - } + const maxNonce = 10000000 + workers := runtime.NumCPU() + if workers < 1 { + workers = 1 } - return "" + + var ( + found atomic.Bool + resultCh = make(chan string, 1) + wg sync.WaitGroup + ) + + for w := 0; w < workers; w++ { + wg.Add(1) + go func(start int) { + defer wg.Done() + for nonce := start; nonce <= maxNonce; nonce += workers { + if found.Load() { + return + } + data := powInput + strconv.Itoa(nonce) + hash := sha256.Sum256([]byte(data)) + hexHash := hex.EncodeToString(hash[:]) + if strings.HasPrefix(hexHash, target) { + if found.CompareAndSwap(false, true) { + resultCh <- hexHash + } + return + } + } + }(w + 1) + } + + go func() { wg.Wait(); close(resultCh) }() + + if h, ok := <-resultCh; ok { + return h, nil + } + return "", fmt.Errorf("PoW unsolved (difficulty=%d, tried %dM nonces)", difficulty, maxNonce/1000000) } -func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamID int, client tlsclient.HttpClient, profile Profile) (string, error) { +func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamID int, client tlsclient.HttpClient, profile Profile, savedProfile *SavedProfile) (string, error) { vkReq := func(method string, postData string) (map[string]interface{}, error) { reqURL := "https://api.vk.ru/method/" + method + "?v=5.131" parsedURL, err := neturl.Parse(reqURL) @@ -490,13 +525,11 @@ func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamI applyBrowserProfileFhttp(req, profile) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "*/*") - req.Header.Set("Origin", "https://id.vk.ru") - req.Header.Set("Referer", "https://id.vk.ru/") - req.Header.Set("Sec-Fetch-Site", "same-site") + req.Header.Set("Origin", "https://api.vk.ru") + req.Header.Set("Referer", fmt.Sprintf("https://api.vk.ru/not_robot_captcha?domain=vk.com&session_token=%s&variant=popup&blank=1", sessionToken)) + req.Header.Set("Sec-Fetch-Site", "same-origin") req.Header.Set("Sec-Fetch-Mode", "cors") req.Header.Set("Sec-Fetch-Dest", "empty") - req.Header.Set("Sec-GPC", "1") - req.Header.Set("Priority", "u=1, i") httpResp, err := client.Do(req) if err != nil { @@ -517,7 +550,14 @@ func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamI return resp, nil } - baseParams := fmt.Sprintf("session_token=%s&domain=vk.com&adFp=&access_token=", neturl.QueryEscape(sessionToken)) + // Per-session adFp: a stable empty value is itself a fingerprint. + adFpBytes := make([]byte, 16) + for i := range adFpBytes { + adFpBytes[i] = byte(rand.Intn(256)) + } + adFp := base64.RawURLEncoding.EncodeToString(adFpBytes)[:21] + + baseParams := fmt.Sprintf("session_token=%s&domain=vk.com&adFp=%s&access_token=", neturl.QueryEscape(sessionToken), neturl.QueryEscape(adFp)) log.Printf("[STREAM %d] [Captcha] Step 1/4: settings", streamID) if _, err := vkReq("captchaNotRobot.settings", baseParams); err != nil { @@ -529,6 +569,10 @@ func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamI log.Printf("[STREAM %d] [Captcha] Step 2/4: componentDone", streamID) browserFp := generateBrowserFp(profile) deviceJSON := buildCaptchaDeviceJSON(profile) + if savedProfile != nil { + browserFp = savedProfile.BrowserFp + deviceJSON = savedProfile.DeviceJSON + } componentDoneData := baseParams + fmt.Sprintf("&browser_fp=%s&device=%s", browserFp, neturl.QueryEscape(deviceJSON)) if _, err := vkReq("captchaNotRobot.componentDone", componentDoneData); err != nil { @@ -538,15 +582,31 @@ func callCaptchaNotRobot(ctx context.Context, sessionToken, hash string, streamI time.Sleep(200 * time.Millisecond) log.Printf("[STREAM %d] [Captcha] Step 3/4: check", streamID) - cursorJSON := generateFakeCursor() + // Real browser sends [] for cursor on the first check. + cursorJSON := "[]" answer := base64.StdEncoding.EncodeToString([]byte("{}")) - // Dynamically generate debug_info to avoid static fingerprint bans - debugInfoBytes := md5.Sum([]byte(profile.UserAgent + strconv.FormatInt(time.Now().UnixNano(), 10))) + // debug_info must vary per-session — a hardcoded hash becomes a stable + // fingerprint VK uses to flag the bot path (status=BOT). + debugInfoBytes := sha256.Sum256([]byte(profile.UserAgent + sessionToken + strconv.FormatInt(time.Now().UnixNano(), 10))) debugInfo := hex.EncodeToString(debugInfoBytes[:]) - connectionRtt := "[50,50,50,50,50,50,50,50,50,50]" - connectionDownlink := "[9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5,9.5]" + // Realistic per-session jitter; static arrays were also a fingerprint. + rttSamples := 4 + rand.Intn(4) + rttBase := 40 + rand.Intn(120) + rttVals := make([]string, rttSamples) + for i := range rttVals { + rttVals[i] = strconv.Itoa(rttBase + rand.Intn(40) - 20) + } + connectionRtt := "[" + strings.Join(rttVals, ",") + "]" + + dlSamples := 4 + rand.Intn(4) + dlBase := 2.0 + rand.Float64()*8.0 + dlVals := make([]string, dlSamples) + for i := range dlVals { + dlVals[i] = strconv.FormatFloat(dlBase+(rand.Float64()-0.5)*0.4, 'f', 2, 64) + } + connectionDownlink := "[" + strings.Join(dlVals, ",") + "]" checkData := baseParams + fmt.Sprintf( "&accelerometer=%s&gyroscope=%s&motion=%s&cursor=%s&taps=%s&connectionRtt=%s&connectionDownlink=%s&browser_fp=%s&hash=%s&answer=%s&debug_info=%s", @@ -594,12 +654,12 @@ type VKCredentials struct { ClientSecret string } +// Only client_ids that currently expose calls.getAnonymousToken. +// VKVIDEO_* and VK_ID_AUTH_APP started returning error_code:3 "Unknown method" +// (observed 2026-04-28) and only burn throttle budget if kept in rotation. var vkCredentialsList = []VKCredentials{ - {ClientID: "6287487", ClientSecret: "QbYic1K3lEV5kTGiqlq2"}, // VK_WEB_APP_ID - {ClientID: "7879029", ClientSecret: "aR5NKGmm03GYrCiNKsaw"}, // VK_MVK_APP_ID - {ClientID: "52461373", ClientSecret: "o557NLIkAErNhakXrQ7A"}, // VK_WEB_VKVIDEO_APP_ID - {ClientID: "52649896", ClientSecret: "WStp4ihWG4l3nmXZgIbC"}, // VK_MVK_VKVIDEO_APP_ID - {ClientID: "51781872", ClientSecret: "IjjCNl4L4Tf5QZEXIHKK"}, // VK_ID_AUTH_APP + {ClientID: "6287487", ClientSecret: "QbYic1K3lEV5kTGiqlq2"}, // VK_WEB_APP_ID + {ClientID: "7879029", ClientSecret: "aR5NKGmm03GYrCiNKsaw"}, // VK_MVK_APP_ID } type TurnCredentials struct { @@ -622,7 +682,10 @@ const ( cacheSafetyMargin = 60 * time.Second maxCacheErrors = 3 errorWindow = 10 * time.Second - streamsPerCache = 10 + // streamsPerCache=1: each stream caches its own slot creds because + // acquireVkTurnSlot mints a unique (username, password) per call. + streamsPerCache = 1 + identityLifetime = 8 * time.Minute ) func getCacheID(streamID int) int { @@ -634,6 +697,24 @@ func vkDelayRandom(minMs, maxMs int) { time.Sleep(time.Duration(ms) * time.Millisecond) } +// sleepCtx waits d or until ctx cancels, whichever comes first. Returns false +// on cancellation. Uses NewTimer+Stop so the timer is reclaimed immediately on +// ctx cancel, not after d expires (avoids long-lived timer leak under repeated +// cancellation/restart cycles). +func sleepCtx(ctx context.Context, d time.Duration) bool { + if d <= 0 { + return ctx.Err() == nil + } + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + } +} + var credentialsStore = struct { mu sync.RWMutex caches map[int]*StreamCredentialsCache @@ -710,7 +791,7 @@ func (c *StreamCredentialsCache) invalidate(streamID int) { log.Printf("[STREAM %d] [VK Auth] Credentials cache invalidated", streamID) } -func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) { +func getVkCredsCached(ctx context.Context, link string, streamID int) (string, string, string, error) { cache := getStreamCache(streamID) cacheID := getCacheID(streamID) @@ -734,7 +815,7 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn return cache.creds.Username, cache.creds.Password, cache.creds.ServerAddr, nil } - user, pass, addr, err := fetchVkCredsSerialized(ctx, link, streamID, dialer) + user, pass, addr, err := fetchVkCreds(ctx, link, streamID) if err != nil { return "", "", "", err } @@ -743,63 +824,196 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn return user, pass, addr, nil } +// vkClientThrottle holds per-client_id serialisation + cooldown timestamp. +// Was previously a single global mutex which forced acquires across distinct +// client_ids onto the same queue even though VK rate-limits per client_id. +type vkClientThrottle struct { + mu sync.Mutex + lastTime time.Time +} + var ( - vkRequestMu sync.Mutex - globalLastVkFetchTime time.Time + vkThrottleStore = struct { + mu sync.Mutex + m map[string]*vkClientThrottle + }{m: make(map[string]*vkClientThrottle)} ) -func fetchVkCredsSerialized(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) { - vkRequestMu.Lock() - defer vkRequestMu.Unlock() +func getVkThrottle(clientID string) *vkClientThrottle { + vkThrottleStore.mu.Lock() + defer vkThrottleStore.mu.Unlock() + t, ok := vkThrottleStore.m[clientID] + if !ok { + t = &vkClientThrottle{} + vkThrottleStore.m[clientID] = t + } + return t +} + +// vkIdentity caches the captcha-gated portion of a VK auth chain (steps 1-3: +// anonym_token + getCallPreview + getAnonymousToken). Once acquired it can be +// replayed via acquireVkTurnSlot to mint independent TURN credentials, each +// with a unique username — bypassing per-username throttling at the cost of a +// single captcha solve per (link, client_id) pair. +type vkIdentity struct { + creds VKCredentials + profile Profile + name string + token1 string + token2 string + client tlsclient.HttpClient + expiresAt time.Time + urlCounter atomic.Uint64 // round-robin index across turn_server.urls +} - // Ensure a minimum cooldown between credential requests to avoid VK rate limits - minInterval := 3*time.Second + time.Duration(rand.Intn(3000))*time.Millisecond - elapsed := time.Since(globalLastVkFetchTime) +type identityCacheKey struct { + link string + clientID string +} - if !globalLastVkFetchTime.IsZero() && elapsed < minInterval { - wait := minInterval - elapsed - log.Printf("[STREAM %d] [VK Auth] Throttling: waiting %v to prevent rate limit...", streamID, wait.Truncate(time.Millisecond)) - select { - case <-ctx.Done(): - return "", "", "", ctx.Err() - case <-time.After(wait): +type identityEntry struct { + mu sync.Mutex + ident *vkIdentity +} + +var identityStore = struct { + mu sync.Mutex + m map[identityCacheKey]*identityEntry +}{m: make(map[identityCacheKey]*identityEntry)} + +// startIdentityJanitor prunes expired identityEntry every period. Long-running +// clients hopping across many distinct (link, client_id) pairs would otherwise +// grow identityStore.m without bound. +// +// Two-phase to avoid blocking acquires: phase 1 snapshots keys under store +// lock; phase 2 inspects each entry via TryLock. If an entry is busy +// (acquireVkIdentity in progress), skip it this round — it will be revisited +// next tick. Only entries with nil or expired ident are deleted. +func startIdentityJanitor(ctx context.Context, period time.Duration) { + go func() { + ticker := time.NewTicker(period) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + + identityStore.mu.Lock() + keys := make([]identityCacheKey, 0, len(identityStore.m)) + for k := range identityStore.m { + keys = append(keys, k) + } + identityStore.mu.Unlock() + + now := time.Now() + for _, key := range keys { + identityStore.mu.Lock() + entry, ok := identityStore.m[key] + identityStore.mu.Unlock() + if !ok { + continue + } + if !entry.mu.TryLock() { + continue // busy — revisit next tick + } + expired := entry.ident == nil || now.After(entry.ident.expiresAt) + entry.mu.Unlock() + if !expired { + continue + } + identityStore.mu.Lock() + // Re-check under store lock + entry lock to avoid racing a + // concurrent acquire that just refilled the entry. + if cur, stillThere := identityStore.m[key]; stillThere && cur == entry { + if entry.mu.TryLock() { + if entry.ident == nil || now.After(entry.ident.expiresAt) { + delete(identityStore.m, key) + } + entry.mu.Unlock() + } + } + identityStore.mu.Unlock() + } } + }() +} + +func getOrAcquireIdentity(ctx context.Context, link string, streamID int, creds VKCredentials) (*vkIdentity, error) { + key := identityCacheKey{link: link, clientID: creds.ClientID} + + identityStore.mu.Lock() + entry, ok := identityStore.m[key] + if !ok { + entry = &identityEntry{} + identityStore.m[key] = entry } + identityStore.mu.Unlock() - defer func() { - globalLastVkFetchTime = time.Now() - }() + entry.mu.Lock() + defer entry.mu.Unlock() + + if entry.ident != nil && time.Now().Before(entry.ident.expiresAt) { + return entry.ident, nil + } + + ident, err := acquireVkIdentity(ctx, link, streamID, creds) + if err != nil { + return nil, err + } + entry.ident = ident + return ident, nil +} - return fetchVkCreds(ctx, link, streamID, dialer) +func invalidateIdentity(link, clientID string) { + identityStore.mu.Lock() + entry, ok := identityStore.m[identityCacheKey{link: link, clientID: clientID}] + identityStore.mu.Unlock() + if !ok { + return + } + entry.mu.Lock() + entry.ident = nil + entry.mu.Unlock() } -func fetchVkCreds(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) { - // Check Global Lockout to prevent API bans +func fetchVkCreds(ctx context.Context, link string, streamID int) (string, string, string, error) { if time.Now().Unix() < globalCaptchaLockout.Load() { return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED: global lockout active") } - var lastErr error - jar := tlsclient.NewCookieJar() + n := len(vkCredentialsList) + startIdx := streamID % n - for _, creds := range vkCredentialsList { + var lastErr error + for offset := 0; offset < n; offset++ { + creds := vkCredentialsList[(startIdx+offset)%n] log.Printf("[STREAM %d] [VK Auth] Trying credentials: client_id=%s", streamID, creds.ClientID) - user, pass, addr, err := getTokenChain(ctx, link, streamID, creds, dialer, jar) + ident, err := getOrAcquireIdentity(ctx, link, streamID, creds) + if err != nil { + lastErr = err + log.Printf("[STREAM %d] [VK Auth] identity acquire failed (client_id=%s): %v", streamID, creds.ClientID, err) + if strings.Contains(err.Error(), "CAPTCHA_WAIT_REQUIRED") || strings.Contains(err.Error(), "FATAL_CAPTCHA") { + return "", "", "", err + } + continue + } + user, pass, addr, err := acquireVkTurnSlot(ctx, link, streamID, ident) if err == nil { log.Printf("[STREAM %d] [VK Auth] Success with client_id=%s", streamID, creds.ClientID) return user, pass, addr, nil } lastErr = err - log.Printf("[STREAM %d] [VK Auth] Failed with client_id=%s: %v", streamID, creds.ClientID, err) + log.Printf("[STREAM %d] [VK Auth] slot acquire failed (client_id=%s): %v", streamID, creds.ClientID, err) + invalidateIdentity(link, creds.ClientID) - // Hard abort on captcha/fatal conditions instead of trying next creds if strings.Contains(err.Error(), "CAPTCHA_WAIT_REQUIRED") || strings.Contains(err.Error(), "FATAL_CAPTCHA") { return "", "", "", err } - if strings.Contains(err.Error(), "error_code:29") || strings.Contains(err.Error(), "error_code: 29") || strings.Contains(err.Error(), "Rate limit") { log.Printf("[STREAM %d] [VK Auth] Rate limit detected, trying next credentials...", streamID) } @@ -808,7 +1022,81 @@ func fetchVkCreds(ctx context.Context, link string, streamID int, dialer *dnsdia return "", "", "", fmt.Errorf("all VK credentials failed: %w", lastErr) } -func getTokenChain(ctx context.Context, link string, streamID int, creds VKCredentials, dialer *dnsdialer.Dialer, jar tlsclient.CookieJar) (string, string, string, error) { +func vkDoRequest(ctx context.Context, client tlsclient.HttpClient, profile Profile, data, url string) (map[string]interface{}, error) { + parsedURL, err := neturl.Parse(url) + if err != nil { + return nil, fmt.Errorf("parse request URL: %w", err) + } + domain := parsedURL.Hostname() + + req, err := fhttp.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer([]byte(data))) + if err != nil { + return nil, err + } + + req.Host = domain + applyBrowserProfileFhttp(req, profile) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "*/*") + req.Header.Set("Origin", "https://vk.ru") + req.Header.Set("Referer", "https://vk.ru/") + req.Header.Set("Sec-Fetch-Site", "same-site") + req.Header.Set("Sec-Fetch-Mode", "cors") + req.Header.Set("Sec-Fetch-Dest", "empty") + req.Header.Set("Priority", "u=1, i") + + httpResp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + if closeErr := httpResp.Body.Close(); closeErr != nil { + log.Printf("close response body: %s", closeErr) + } + }() + + body, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, err + } + + var resp map[string]interface{} + if err := json.Unmarshal(body, &resp); err != nil { + return nil, err + } + return resp, nil +} + +// acquireVkIdentity runs the heavy + captcha-gated portion of the VK auth chain +// (steps 1-3: get_anonym_token, calls.getCallPreview, calls.getAnonymousToken). +// The result is cached and reused across many TURN slot acquisitions. +// +// Per-client_id serialised + 3-6s cooldown to avoid VK API bans. Distinct +// client_ids run in parallel. +func acquireVkIdentity(ctx context.Context, link string, streamID int, creds VKCredentials) (*vkIdentity, error) { + throttle := getVkThrottle(creds.ClientID) + throttle.mu.Lock() + defer throttle.mu.Unlock() + + minInterval := 3*time.Second + time.Duration(rand.Intn(3000))*time.Millisecond + elapsed := time.Since(throttle.lastTime) + if !throttle.lastTime.IsZero() && elapsed < minInterval { + wait := minInterval - elapsed + log.Printf("[STREAM %d] [VK Auth] Throttling client_id=%s: waiting %v...", streamID, creds.ClientID, wait.Truncate(time.Millisecond)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(wait): + } + } + defer func() { + throttle.lastTime = time.Now() + }() + + if time.Now().Unix() < globalCaptchaLockout.Load() { + return nil, fmt.Errorf("CAPTCHA_WAIT_REQUIRED: global lockout active") + } + profile := Profile{ UserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Safari/537.36", SecChUa: `"Not(A:Brand";v="99", "Google Chrome";v="146", "Chromium";v="146"`, @@ -816,101 +1104,66 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede SecChUaPlatform: `"Windows"`, } + jar := tlsclient.NewCookieJar() client, err := tlsclient.NewHttpClient(tlsclient.NewNoopLogger(), tlsclient.WithTimeoutSeconds(20), tlsclient.WithClientProfile(profiles.Chrome_146), tlsclient.WithCookieJar(jar), - tlsclient.WithDialer(getCustomNetDialer()), + tlsclient.WithDialer(appDialer()), ) if err != nil { - return "", "", "", fmt.Errorf("failed to initialize tls_client: %w", err) + return nil, fmt.Errorf("failed to initialize tls_client: %w", err) } name := generateName() escapedName := neturl.QueryEscape(name) - log.Printf("[STREAM %d] [VK Auth] Connecting Identity - Name: %s | User-Agent: %s", streamID, name, profile.UserAgent) + log.Printf("[STREAM %d] [VK Auth] Connecting Identity - Name: %s | client_id=%s", streamID, name, creds.ClientID) - doRequest := func(data string, url string) (resp map[string]interface{}, err error) { - parsedURL, err := neturl.Parse(url) - if err != nil { - return nil, fmt.Errorf("parse request URL: %w", err) - } - domain := parsedURL.Hostname() - - req, err := fhttp.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer([]byte(data))) - if err != nil { - return nil, err - } - - req.Host = domain - applyBrowserProfileFhttp(req, profile) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "*/*") - req.Header.Set("Origin", "https://vk.ru") - req.Header.Set("Referer", "https://vk.ru/") - req.Header.Set("Sec-Fetch-Site", "same-site") - req.Header.Set("Sec-Fetch-Mode", "cors") - req.Header.Set("Sec-Fetch-Dest", "empty") - req.Header.Set("Priority", "u=1, i") - - httpResp, err := client.Do(req) - if err != nil { - return nil, err - } - defer func() { - if closeErr := httpResp.Body.Close(); closeErr != nil { - log.Printf("close response body: %s", closeErr) - } - }() - - body, err := io.ReadAll(httpResp.Body) - if err != nil { - return nil, err - } - - err = json.Unmarshal(body, &resp) - if err != nil { - return nil, err - } - return resp, nil - } - - // Token 1 + // Step 1: anonym_token data := fmt.Sprintf("client_id=%s&token_type=messages&client_secret=%s&version=1&app_id=%s", creds.ClientID, creds.ClientSecret, creds.ClientID) - resp, err := doRequest(data, "https://login.vk.ru/?act=get_anonym_token") + resp, err := vkDoRequest(ctx, client, profile, data, "https://login.vk.ru/?act=get_anonym_token") if err != nil { - return "", "", "", err + return nil, err } dataMap, ok := resp["data"].(map[string]interface{}) if !ok { - return "", "", "", fmt.Errorf("unexpected anon token response: %v", resp) + return nil, fmt.Errorf("unexpected anon token response: %v", resp) } token1, ok := dataMap["access_token"].(string) if !ok { - return "", "", "", fmt.Errorf("missing access_token in response: %v", resp) + return nil, fmt.Errorf("missing access_token in response: %v", resp) } vkDelayRandom(100, 150) - // Token 1 -> getCallPreview + // Step 2: getCallPreview (best-effort) data = fmt.Sprintf("vk_join_link=https://vk.com/call/join/%s&fields=photo_200&access_token=%s", link, token1) - _, err = doRequest(data, "https://api.vk.ru/method/calls.getCallPreview?v=5.275&client_id="+creds.ClientID) + _, err = vkDoRequest(ctx, client, profile, data, "https://api.vk.ru/method/calls.getCallPreview?v=5.275&client_id="+creds.ClientID) if err != nil { log.Printf("[STREAM %d] [VK Auth] Warning: getCallPreview failed: %v", streamID, err) } vkDelayRandom(200, 400) - // Token 2 + // Step 3: getAnonymousToken (captcha-gated) data = fmt.Sprintf("vk_join_link=https://vk.com/call/join/%s&name=%s&access_token=%s", link, escapedName, token1) urlAddr := fmt.Sprintf("https://api.vk.ru/method/calls.getAnonymousToken?v=5.275&client_id=%s", creds.ClientID) + exhaustedCaptcha := func() error { + globalCaptchaLockout.Store(time.Now().Add(60 * time.Second).Unix()) + if connectedStreams.Load() == 0 { + log.Printf("[STREAM %d] [FATAL] 0 connected streams and captcha solve modes exhausted.", streamID) + return fmt.Errorf("FATAL_CAPTCHA_FAILED_NO_STREAMS") + } + return fmt.Errorf("CAPTCHA_WAIT_REQUIRED") + } + var token2 string for attempt := 0; ; attempt++ { - resp, err = doRequest(data, urlAddr) + resp, err = vkDoRequest(ctx, client, profile, data, urlAddr) if err != nil { - return "", "", "", err + return nil, err } if errObj, hasErr := resp["error"].(map[string]interface{}); hasErr { @@ -919,16 +1172,7 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede solveMode, hasSolveMode := captchaSolveModeForAttempt(attempt, manualCaptcha, autoCaptchaSliderPOC) if !hasSolveMode { log.Printf("[STREAM %d] [Captcha] No more solve modes available (attempt %d)", streamID, attempt+1) - - // Engage global lockout to protect API - globalCaptchaLockout.Store(time.Now().Add(60 * time.Second).Unix()) - - if connectedStreams.Load() == 0 { - log.Printf("[STREAM %d] [FATAL] 0 connected streams and captcha solve modes exhausted.", streamID) - return "", "", "", fmt.Errorf("FATAL_CAPTCHA_FAILED_NO_STREAMS") - } - - return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED") + return nil, exhaustedCaptcha() } var successToken string @@ -956,7 +1200,9 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede } case captchaSolveModeManual: log.Printf("[STREAM %d] [Captcha] Triggering manual captcha fallback...", streamID) - manualCtx, manualCancel := context.WithTimeout(ctx, 60*time.Second) + // Manual solve waits on a human; keep generous timeout + // independent of any auth-level deadline. + manualCtx, manualCancel := context.WithTimeout(context.Background(), 3*time.Minute) type manualRes struct { token string @@ -969,7 +1215,7 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede var t, k string var e error if captchaErr.RedirectURI != "" { - t, e = solveCaptchaViaProxy(captchaErr.RedirectURI, dialer) + t, e = solveCaptchaViaProxy(captchaErr.RedirectURI) } else if captchaErr.CaptchaImg != "" { k, e = solveCaptchaViaHTTP(captchaErr.CaptchaImg) } else { @@ -980,16 +1226,24 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede select { case res := <-resCh: - successToken = res.token - captchaKey = res.key - solveErr = res.err + // Token can arrive even when err != nil (e.g. server + // Shutdown timeout after the token was already received). + // A non-empty token/key counts as success. + if res.token != "" || res.key != "" { + successToken = res.token + captchaKey = res.key + if res.err != nil { + log.Printf("[STREAM %d] [Captcha] Token received (ignoring cleanup error: %v)", streamID, res.err) + } + } else { + solveErr = res.err + } case <-manualCtx.Done(): - solveErr = fmt.Errorf("manual captcha timed out after 60s") + solveErr = fmt.Errorf("manual captcha timed out after 3m") } manualCancel() } - // If solving failed (auto or manual) or timed out if solveErr != nil { log.Printf("[STREAM %d] [Captcha] %s failed (attempt %d): %v", streamID, captchaSolveModeLabel(solveMode), attempt+1, solveErr) @@ -998,17 +1252,7 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede log.Printf("[STREAM %d] [Captcha] Falling back to %s...", streamID, captchaSolveModeLabel(nextSolveMode)) continue } - - // Engage global lockout to protect API - globalCaptchaLockout.Store(time.Now().Add(60 * time.Second).Unix()) - - // If we have 0 streams alive, this is fatal - if connectedStreams.Load() == 0 { - log.Printf("[STREAM %d] [FATAL] 0 connected streams and manual captcha failed/timed out.", streamID) - return "", "", "", fmt.Errorf("FATAL_CAPTCHA_FAILED_NO_STREAMS") - } - - return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED") + return nil, exhaustedCaptcha() } if captchaErr.CaptchaAttempt == "0" || captchaErr.CaptchaAttempt == "" { @@ -1024,26 +1268,41 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede } continue } - return "", "", "", fmt.Errorf("VK API error: %v", errObj) + return nil, fmt.Errorf("VK API error: %v", errObj) } respMap, okLoop := resp["response"].(map[string]interface{}) if !okLoop { - return "", "", "", fmt.Errorf("unexpected getAnonymousToken response: %v", resp) + return nil, fmt.Errorf("unexpected getAnonymousToken response: %v", resp) } token2, okLoop = respMap["token"].(string) if !okLoop { - return "", "", "", fmt.Errorf("missing token in response: %v", resp) + return nil, fmt.Errorf("missing token in response: %v", resp) } break } - vkDelayRandom(100, 150) - - // Token 3 + return &vkIdentity{ + creds: creds, + profile: profile, + name: name, + token1: token1, + token2: token2, + client: client, + expiresAt: time.Now().Add(identityLifetime), + }, nil +} + +// acquireVkTurnSlot runs the lightweight portion of the chain (steps 4-5): +// auth.anonymLogin (with a fresh device_id) followed by vchat.joinConversationByLink. +// Each call returns a distinct (username, password) pair from VK, which lets us +// run multiple parallel TURN allocations under the same identity — bypassing +// per-username throttling without re-solving captcha. +func acquireVkTurnSlot(ctx context.Context, link string, streamID int, ident *vkIdentity) (string, string, string, error) { + // Step 4: auth.anonymLogin with fresh device_id → fresh session_key sessionData := fmt.Sprintf(`{"version":2,"device_id":"%s","client_version":1.1,"client_type":"SDK_JS"}`, uuid.New()) - data = fmt.Sprintf("session_data=%s&method=auth.anonymLogin&format=JSON&application_key=CGMMEJLGDIHBABABA", neturl.QueryEscape(sessionData)) - resp, err = doRequest(data, "https://calls.okcdn.ru/fb.do") + data := fmt.Sprintf("session_data=%s&method=auth.anonymLogin&format=JSON&application_key=CGMMEJLGDIHBABABA", neturl.QueryEscape(sessionData)) + resp, err := vkDoRequest(ctx, ident.client, ident.profile, data, "https://calls.okcdn.ru/fb.do") if err != nil { return "", "", "", err } @@ -1054,9 +1313,9 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede vkDelayRandom(100, 150) - // Token 4 -> TURN Creds - data = fmt.Sprintf("joinLink=%s&isVideo=false&protocolVersion=5&capabilities=2F7F&anonymToken=%s&method=vchat.joinConversationByLink&format=JSON&application_key=CGMMEJLGDIHBABABA&session_key=%s", link, token2, token3) - resp, err = doRequest(data, "https://calls.okcdn.ru/fb.do") + // Step 5: vchat.joinConversationByLink → turn_server creds + data = fmt.Sprintf("joinLink=%s&isVideo=false&protocolVersion=5&capabilities=2F7F&anonymToken=%s&method=vchat.joinConversationByLink&format=JSON&application_key=CGMMEJLGDIHBABABA&session_key=%s", link, ident.token2, token3) + resp, err = vkDoRequest(ctx, ident.client, ident.profile, data, "https://calls.okcdn.ru/fb.do") if err != nil { return "", "", "", err } @@ -1077,11 +1336,44 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede if !ok || len(urlsRaw) == 0 { return "", "", "", fmt.Errorf("missing or empty urls in turn_server") } - urlStr, ok := urlsRaw[0].(string) - if !ok { - return "", "", "", fmt.Errorf("turn server url is not a string") + if isDebug { + log.Printf("[STREAM %d] [VK Auth] turn_server urls: %v", streamID, urlsRaw) + } + + // Prefer URLs whose transport matches the requested mode (udpMode). + // Per RFC 7065, "?transport=tcp" → TCP, missing or "transport=udp" → UDP. + // Fall back to the full list if nothing matches — this preserves the + // -port override path where the user intentionally dials a port not + // advertised in the URL list. + all := make([]string, 0, len(urlsRaw)) + preferred := make([]string, 0, len(urlsRaw)) + for _, raw := range urlsRaw { + s, ok := raw.(string) + if !ok { + continue + } + all = append(all, s) + isTCP := strings.Contains(s, "transport=tcp") + if udpMode == !isTCP { + preferred = append(preferred, s) + } + } + if len(all) == 0 { + return "", "", "", fmt.Errorf("turn_server urls list contained no strings: %v", urlsRaw) + } + + pool := preferred + if len(pool) == 0 { + pool = all + log.Printf("[STREAM %d] [VK Auth] no urls match transport (udp=%v), falling back to full list (relying on -port override). urls=%v", streamID, udpMode, all) } + // Round-robin within the identity. streamID%len(pool) collapses every + // stream of the identity onto the same parity, so use a counter instead. + urlIdx := int(ident.urlCounter.Add(1)-1) % len(pool) + urlStr := pool[urlIdx] + log.Printf("[STREAM %d] [VK Auth] turn_server urls=%d (preferred=%d), picked[%d]: %s", streamID, len(all), len(preferred), urlIdx, urlStr) + clean := strings.Split(urlStr, "?")[0] address := strings.TrimPrefix(strings.TrimPrefix(clean, "turn:"), "turns:") @@ -1209,10 +1501,12 @@ func getYandexCreds(link string) (string, string, string, error) { } endpoint := "https://" + telemostConfHost + telemostConfPath + appD := appDialer() tr := &http.Transport{ MaxIdleConns: 100, MaxIdleConnsPerHost: 100, IdleConnTimeout: 90 * time.Second, + DialContext: appD.DialContext, } client := &http.Client{ Timeout: 20 * time.Second, @@ -1264,7 +1558,8 @@ func getYandexCreds(link string) (string, string, string, error) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - dialer := websocket.Dialer{} + wsAppD := appDialer() + dialer := websocket.Dialer{NetDialContext: wsAppD.DialContext} var conn *websocket.Conn conn, resp, err = dialer.DialContext(ctx, data.Wss, h) if err != nil { @@ -1356,7 +1651,8 @@ func getYandexCreds(link string) (string, string, string, error) { return "", "", "", fmt.Errorf("ws set read deadline: %w", err) } - for { + const maxWsMessages = 64 + for i := 0; i < maxWsMessages; i++ { _, msg, err := conn.ReadMessage() if err != nil { return "", "", "", fmt.Errorf("ws read: %w", err) @@ -1393,6 +1689,7 @@ func getYandexCreds(link string) (string, string, string, error) { } } } + return "", "", "", fmt.Errorf("ws read: no serverHello with TURN urls after %d messages", maxWsMessages) } func dtlsFunc(ctx context.Context, conn net.PacketConn, peer *net.UDPAddr) (net.Conn, error) { @@ -1430,7 +1727,7 @@ func dtlsFunc(ctx context.Context, conn net.PacketConn, peer *net.UDPAddr) (net. } func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.PacketConn, inboundChan <-chan *UDPPacket, connchan chan<- net.PacketConn, okchan chan<- struct{}, streamID int) error { - time.Sleep(time.Duration(rand.Intn(400)+100) * time.Millisecond) + time.Sleep(time.Duration(rand.Intn(100)+30) * time.Millisecond) dtlsctx, dtlscancel := context.WithCancel(ctx) defer dtlscancel() @@ -1491,20 +1788,31 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa defer wg.Done() defer dtlscancel() buf := make([]byte, 1600) + var cachedAddr net.Addr + var cachedPtr any for { n, err1 := dtlsConn.Read(buf) if err1 != nil { return } - // Send back to the active WG client - if peerAddr := activeLocalPeer.Load(); peerAddr != nil { - if addr, ok := peerAddr.(net.Addr); ok { - if _, err := listenConn.WriteTo(buf[:n], addr); err != nil { - log.Printf("[STREAM %d] failed to forward packet to local peer: %v", streamID, err) - } + // Send back to the active WG client. Cache addr locally — only + // re-resolve when atomic.Value pointer changes (rare). + peerAddr := activeLocalPeer.Load() + if peerAddr == nil { + continue + } + if peerAddr != cachedPtr { + if a, ok := peerAddr.(net.Addr); ok { + cachedAddr = a + cachedPtr = peerAddr + } else { + continue } } + if _, err := listenConn.WriteTo(buf[:n], cachedAddr); err != nil { + log.Printf("[STREAM %d] failed to forward packet to local peer: %v", streamID, err) + } } }() @@ -1523,6 +1831,58 @@ func (c *connectedUDPConn) WriteTo(p []byte, _ net.Addr) (int, error) { return c.Write(p) } +type countingConn struct { + net.Conn + written atomic.Int64 + read atomic.Int64 +} + +func (c *countingConn) Read(p []byte) (int, error) { + n, err := c.Conn.Read(p) + if n > 0 { + c.read.Add(int64(n)) + } + return n, err +} + +func (c *countingConn) Write(p []byte) (int, error) { + n, err := c.Conn.Write(p) + if n > 0 { + c.written.Add(int64(n)) + } + return n, err +} + +func classifyNetErr(err error) string { + if err == nil { + return "nil" + } + if errors.Is(err, context.DeadlineExceeded) { + return "ctx-deadline" + } + if errors.Is(err, io.EOF) { + return "eof" + } + if errors.Is(err, syscall.ECONNRESET) { + return "rst" + } + if errors.Is(err, syscall.ECONNREFUSED) { + return "refused" + } + if errors.Is(err, syscall.EPIPE) { + return "broken-pipe" + } + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + return "net-timeout" + } + var oe *net.OpError + if errors.As(err, &oe) { + return "op:" + oe.Op + } + return "other" +} + type turnParams struct { host string port string @@ -1531,8 +1891,129 @@ type turnParams struct { getCreds getCredsFunc } +// turnAllocation bundles a single TURN session: dial socket, TURN client, relay PacketConn. +type turnAllocation struct { + dialConn io.Closer + client *turn.Client + relay net.PacketConn +} + +func (a *turnAllocation) close() { + if a.relay != nil { + _ = a.relay.Close() + } + if a.client != nil { + a.client.Close() + } + if a.dialConn != nil { + _ = a.dialConn.Close() + } +} + +// dialTurn opens a fresh TURN session under the given (user, pass). Each call +// produces an independent 5-tuple (own source UDP/TCP port) and an independent +// TURN allocation. VK may or may not allow multiple allocations under the same +// credentials — caller must tolerate failures on additional sessions. +func dialTurn(ctx context.Context, useUDP bool, turnServerAddr string, turnServerUDPAddr *net.UDPAddr, addrFamily turn.RequestedAddressFamily, user, pass string, streamID int) (*turnAllocation, error) { + var dialCloser io.Closer + var turnConn net.PacketConn + if useUDP { + conn, err := net.DialUDP("udp", nil, turnServerUDPAddr) + if err != nil { + return nil, fmt.Errorf("failed to connect to TURN server: %w", err) + } + dialCloser = conn + turnConn = &connectedUDPConn{conn} + } else { + ctx1, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + var d net.Dialer + conn, err := d.DialContext(ctx1, "tcp", turnServerAddr) + if err != nil { + log.Printf("[STREAM %d] [TURN] tcp dial %s failed: class=%s err=%v", + streamID, turnServerAddr, classifyNetErr(err), err) + return nil, fmt.Errorf("failed to connect to TURN server: %w", err) + } + if isDebug { + log.Printf("[STREAM %d] [TURN] tcp established %s -> %s", + streamID, conn.LocalAddr(), conn.RemoteAddr()) + } + dialCloser = conn + turnConn = turn.NewSTUNConn(&countingConn{Conn: conn}) + } + + cfg := &turn.ClientConfig{ + STUNServerAddr: turnServerAddr, + TURNServerAddr: turnServerAddr, + Conn: turnConn, + Net: newDirectNet(), + Username: user, + Password: pass, + RequestedAddressFamily: addrFamily, + LoggerFactory: logging.NewDefaultLoggerFactory(), + } + + client, err := turn.NewClient(cfg) + if err != nil { + _ = dialCloser.Close() + return nil, fmt.Errorf("failed to create TURN client: %w", err) + } + + if err := client.Listen(); err != nil { + client.Close() + _ = dialCloser.Close() + return nil, fmt.Errorf("failed to listen: %w", err) + } + + relay, err := client.Allocate() + if err != nil { + client.Close() + _ = dialCloser.Close() + return nil, fmt.Errorf("failed to allocate: %w", err) + } + + return &turnAllocation{dialConn: dialCloser, client: client, relay: relay}, nil +} + +// relayPool is a concurrent ring of live relay PacketConns. Reads (pick) are +// fully lock-free via atomic.Pointer to a snapshot slice (copy-on-write). +// add is rare and pays the alloc cost. +type relayPool struct { + relays atomic.Pointer[[]net.PacketConn] + addMu sync.Mutex + counter atomic.Uint64 +} + +func (p *relayPool) add(r net.PacketConn) { + p.addMu.Lock() + defer p.addMu.Unlock() + cur := p.relays.Load() + var next []net.PacketConn + if cur != nil { + next = make([]net.PacketConn, len(*cur)+1) + copy(next, *cur) + next[len(*cur)] = r + } else { + next = []net.PacketConn{r} + } + p.relays.Store(&next) +} + +func (p *relayPool) pick() net.PacketConn { + cur := p.relays.Load() + if cur == nil { + return nil + } + n := len(*cur) + if n == 0 { + return nil + } + idx := int(p.counter.Add(1)-1) % n + return (*cur)[idx] +} + func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UDPAddr, conn2 net.PacketConn, streamID int, c chan<- error) { - time.Sleep(time.Duration(rand.Intn(400)+100) * time.Millisecond) + time.Sleep(time.Duration(rand.Intn(100)+30) * time.Millisecond) var err error defer func() { c <- err }() user, pass, urlTarget, err1 := turnParams.getCreds(ctx, turnParams.link, streamID) @@ -1551,8 +2032,8 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD if turnParams.port != "" { urlport = turnParams.port } - var turnServerAddr string - turnServerAddr = net.JoinHostPort(urlhost, urlport) + turnServerAddr := net.JoinHostPort(urlhost, urlport) + log.Printf("[STREAM %d] [TURN] dialing %s (udp=%v)", streamID, turnServerAddr, turnParams.udp) turnServerUDPAddr, err1 := net.ResolveUDPAddr("udp", turnServerAddr) if err1 != nil { err = fmt.Errorf("failed to resolve TURN server address: %s", err1) @@ -1560,38 +2041,7 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD } turnServerAddr = turnServerUDPAddr.String() fmt.Println(turnServerUDPAddr.IP) - var cfg *turn.ClientConfig - var turnConn net.PacketConn - var d net.Dialer - ctx1, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if turnParams.udp { - conn, err2 := net.DialUDP("udp", nil, turnServerUDPAddr) // nolint: noctx - if err2 != nil { - err = fmt.Errorf("failed to connect to TURN server: %s", err2) - return - } - defer func() { - if err1 = conn.Close(); err1 != nil { - err = fmt.Errorf("failed to close TURN server connection: %s", err1) - return - } - }() - turnConn = &connectedUDPConn{conn} - } else { - conn, err2 := d.DialContext(ctx1, "tcp", turnServerAddr) - if err2 != nil { - err = fmt.Errorf("failed to connect to TURN server: %s", err2) - return - } - defer func() { - if err1 = conn.Close(); err1 != nil { - err = fmt.Errorf("failed to close TURN server connection: %s", err1) - return - } - }() - turnConn = turn.NewSTUNConn(conn) - } + var addrFamily turn.RequestedAddressFamily if peer.IP.To4() != nil { addrFamily = turn.RequestedAddressFamilyIPv4 @@ -1599,66 +2049,83 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD addrFamily = turn.RequestedAddressFamilyIPv6 } - cfg = &turn.ClientConfig{ - STUNServerAddr: turnServerAddr, - TURNServerAddr: turnServerAddr, - Conn: turnConn, - Net: newDirectNet(), - Username: user, - Password: pass, - RequestedAddressFamily: addrFamily, - LoggerFactory: logging.NewDefaultLoggerFactory(), - } - - client, err1 := turn.NewClient(cfg) - if err1 != nil { - err = fmt.Errorf("failed to create TURN client: %s", err1) - return - } - defer client.Close() - - err1 = client.Listen() - if err1 != nil { - err = fmt.Errorf("failed to listen: %s", err1) - return - } - - relayConn, err1 := client.Allocate() + primary, err1 := dialTurn(ctx, turnParams.udp, turnServerAddr, turnServerUDPAddr, addrFamily, user, pass, streamID) if err1 != nil { if isAuthError(err1) { handleAuthError(streamID) } - err = fmt.Errorf("failed to allocate: %s", err1) + err = err1 return } - // Reset error count on successful allocation getStreamCache(streamID).errorCount.Store(0) - // Safely track active streams globally connectedStreams.Add(1) - defer func() { - connectedStreams.Add(-1) - if err1 := relayConn.Close(); err1 != nil { - err = fmt.Errorf("failed to close TURN allocated connection: %s", err1) - } - }() + defer connectedStreams.Add(-1) if isDebug { - log.Printf("[STREAM %d] relayed-address=%s", streamID, relayConn.LocalAddr().String()) + log.Printf("[STREAM %d] relayed-address=%s", streamID, primary.relay.LocalAddr().String()) } - wg := sync.WaitGroup{} - wg.Add(1) + pool := &relayPool{} + pool.add(primary.relay) + turnctx, turncancel := context.WithCancel(ctx) + defer turncancel() + + // Track all allocations for clean shutdown. + allocs := []*turnAllocation{primary} + var allocsMu sync.Mutex + defer func() { + allocsMu.Lock() + toClose := allocs + allocs = nil + allocsMu.Unlock() + for _, a := range toClose { + a.close() + } + }() + context.AfterFunc(turnctx, func() { - if err := relayConn.SetDeadline(time.Now()); err != nil { - log.Printf("Failed to set relay deadline: %s", err) + allocsMu.Lock() + defer allocsMu.Unlock() + for _, a := range allocs { + if a.relay != nil { + _ = a.relay.SetDeadline(time.Now()) + } } - // Do not set conn2 deadline (conn2 can sometimes be listenConn if direct mode is used) }) + var internalPipeAddr atomic.Value + // Per-relay inbound goroutine: read from its own relay, forward to conn2. + var inboundWg sync.WaitGroup + spawnInbound := func(relay net.PacketConn) { + inboundWg.Add(1) + go func() { + defer inboundWg.Done() + defer turncancel() + buf := make([]byte, 1600) + for { + n, _, err1 := relay.ReadFrom(buf) + if err1 != nil { + return + } + addr1 := internalPipeAddr.Load() + if addr1 == nil { + continue + } + if addr, ok := addr1.(net.Addr); ok { + if _, err := conn2.WriteTo(buf[:n], addr); err != nil { + return + } + } + } + }() + } + spawnInbound(primary.relay) + + // Outbound: read from conn2, send via round-robin across the relay pool. go func() { defer turncancel() buf := make([]byte, 1600) @@ -1673,42 +2140,51 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD if turnctx.Err() != nil { return } - internalPipeAddr.Store(addr1) - _, err1 = relayConn.WriteTo(buf[:n], peer) - if err1 != nil { + r := pool.pick() + if r == nil { + return + } + if _, err1 = r.WriteTo(buf[:n], peer); err1 != nil { return } } }() - go func() { - defer wg.Done() - defer turncancel() - buf := make([]byte, 1600) - for { - n, _, err1 := relayConn.ReadFrom(buf) - if err1 != nil { + // Open extra allocations under the same creds. DTLS handshake completes + // over the primary first; deferring extras lets the server install the + // Connection ID so subsequent multi-path packets are matched to the + // existing session via CID rather than 5-tuple. + extras := allocsPerStream - 1 + if extras > 0 { + go func() { + select { + case <-turnctx.Done(): return + case <-time.After(1 * time.Second): } - addr1 := internalPipeAddr.Load() - if addr1 == nil { - continue - } - - if addr, ok := addr1.(net.Addr); ok { - if _, err := conn2.WriteTo(buf[:n], addr); err != nil { + for i := 0; i < extras; i++ { + if turnctx.Err() != nil { return } + extra, err := dialTurn(ctx, turnParams.udp, turnServerAddr, turnServerUDPAddr, addrFamily, user, pass, streamID) + if err != nil { + log.Printf("[STREAM %d] [TURN] extra alloc %d/%d failed: %v", streamID, i+1, extras, err) + continue + } + log.Printf("[STREAM %d] [TURN] extra alloc %d/%d OK relay=%s", streamID, i+1, extras, extra.relay.LocalAddr()) + allocsMu.Lock() + allocs = append(allocs, extra) + allocsMu.Unlock() + pool.add(extra.relay) + spawnInbound(extra.relay) + time.Sleep(200 * time.Millisecond) } - } - }() - - wg.Wait() - if err := relayConn.SetDeadline(time.Time{}); err != nil { - log.Printf("Failed to clear relay deadline: %s", err) + }() } + + inboundWg.Wait() } func oneDtlsConnectionLoop(ctx context.Context, peer *net.UDPAddr, listenConn net.PacketConn, inboundChan <-chan *UDPPacket, connchan chan<- net.PacketConn, okchan chan<- struct{}, streamID int) { @@ -1722,10 +2198,8 @@ func oneDtlsConnectionLoop(ctx context.Context, peer *net.UDPAddr, listenConn ne if time.Now().Unix() < globalCaptchaLockout.Load() && strings.Contains(err.Error(), "context deadline exceeded") { continue } - select { - case <-ctx.Done(): + if !sleepCtx(ctx, time.Duration(10+rand.Intn(20))*time.Second) { return - case <-time.After(time.Duration(10+rand.Intn(20)) * time.Second): } } } @@ -1757,10 +2231,8 @@ func oneTurnConnectionLoop(ctx context.Context, turnParams *turnParams, peer *ne if strings.Contains(err.Error(), "CAPTCHA_WAIT_REQUIRED") { if !strings.Contains(err.Error(), "global lockout active") { log.Printf("[STREAM %d] Backing off for 60 seconds to avoid IP ban...", streamID) - select { - case <-ctx.Done(): + if !sleepCtx(ctx, 60*time.Second) { return - case <-time.After(60 * time.Second): } } else { lockoutEnd := globalCaptchaLockout.Load() @@ -1768,10 +2240,8 @@ func oneTurnConnectionLoop(ctx context.Context, turnParams *turnParams, peer *ne if sleepDuration < 0 { sleepDuration = 5 * time.Second } - select { - case <-ctx.Done(): + if !sleepCtx(ctx, sleepDuration) { return - case <-time.After(sleepDuration): } } } else { @@ -1812,7 +2282,21 @@ func main() { vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") debugFlag := flag.Bool("debug", false, "enable debug logging") manualCaptchaFlag := flag.Bool("manual-captcha", false, "skip auto captcha solving, use manual mode immediately") + dnsFlag := flag.String("dns", DNSModeAuto, "DNS resolution mode: udp | doh | auto (auto tries UDP/53 first, sticky-fallback to DoH on total failure)") + allocsFlag := flag.Int("allocs-per-stream", 1, "open this many TURN allocations per stream under shared creds (only useful if VK throttles per-allocation)") + handshakeConc := flag.Int("handshake-concurrency", 8, "max concurrent DTLS handshakes") flag.Parse() + if *handshakeConc < 1 { + *handshakeConc = 1 + } + handshakeSem = make(chan struct{}, *handshakeConc) + switch *dnsFlag { + case DNSModeUDP, DNSModeDoH, DNSModeAuto: + dnsMode = *dnsFlag + default: + log.Panicf("invalid -dns value %q (expected udp|doh|auto)", *dnsFlag) + } + log.Printf("[DNS] mode=%s", dnsMode) if *peerAddr == "" { log.Panicf("Need peer address!") } @@ -1827,6 +2311,13 @@ func main() { isDebug = *debugFlag manualCaptcha = *manualCaptchaFlag autoCaptchaSliderPOC = !manualCaptcha + allocsPerStream = *allocsFlag + if allocsPerStream < 1 { + allocsPerStream = 1 + } + udpMode = *udp + + startIdentityJanitor(ctx, 5*time.Minute) var link string var getCreds getCredsFunc @@ -1834,14 +2325,8 @@ func main() { parts := strings.Split(*vklink, "join/") link = parts[len(parts)-1] - dialer := dnsdialer.New( - dnsdialer.WithResolvers("77.88.8.8:53", "77.88.8.1:53", "8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"), - dnsdialer.WithStrategy(dnsdialer.Fallback{}), - dnsdialer.WithCache(100, 10*time.Hour, 10*time.Hour), - ) - getCreds = func(ctx context.Context, s string, streamID int) (string, string, string, error) { - return getVkCredsCached(ctx, s, streamID, dialer) + return getVkCredsCached(ctx, s, streamID) } if *n <= 0 { *n = 10 @@ -1889,9 +2374,24 @@ func main() { } // Shared Worker Pool Queue for Aggregation - inboundChan := make(chan *UDPPacket, 2000) + inboundChan := make(chan *UDPPacket, 8192) + var droppedPkts atomic.Uint64 go func() { + dropTicker := time.NewTicker(5 * time.Second) + defer dropTicker.Stop() + var lastDropped uint64 + for range dropTicker.C { + cur := droppedPkts.Load() + if cur != lastDropped { + log.Printf("[inbound] dropped %d pkts since start (delta=%d) — queue saturated", cur, cur-lastDropped) + lastDropped = cur + } + } + }() + + go func() { + var lastAddrStr string for { pktIface := packetPool.Get() pkt, ok := pktIface.(*UDPPacket) @@ -1904,16 +2404,12 @@ func main() { return } - // Save the local WireGuard peer address - current := activeLocalPeer.Load() - if current == nil { - activeLocalPeer.Store(addr) - } else if addrStr, ok := current.(net.Addr); ok { - if addrStr.String() != addr.String() { - activeLocalPeer.Store(addr) - } - } else { + // Save local WireGuard peer addr; cache string to avoid repeated + // type-assert + String() in hot path. + s := addr.String() + if s != lastAddrStr { activeLocalPeer.Store(addr) + lastAddrStr = s } pkt.N = nRead @@ -1921,14 +2417,14 @@ func main() { select { case inboundChan <- pkt: default: - // Drop the packet only if the global queue is completely full + droppedPkts.Add(1) packetPool.Put(pkt) } } }() wg1 := sync.WaitGroup{} - t := time.Tick(200 * time.Millisecond) + t := time.Tick(100 * time.Millisecond) if *direct { log.Panicf("Direct mode not supported with dispatcher") @@ -1952,7 +2448,7 @@ func main() { case <-ctx.Done(): } - for i := 1; i < numStreams; i++ { + for i := 2; i <= numStreams; i++ { cchan := make(chan net.PacketConn) wg1.Add(1) go func(streamID int) { @@ -1969,45 +2465,57 @@ func main() { wg1.Wait() } -// sessionPool manages a pool of smux sessions for round-robin TCP distribution. +// sessionPool manages smux sessions for round-robin TCP distribution. +// Lock-free reads via atomic.Pointer copy-on-write snapshot. type sessionPool struct { - mu sync.RWMutex - sessions []*smux.Session + sessions atomic.Pointer[[]*smux.Session] + mu sync.Mutex counter atomic.Uint64 } +func (p *sessionPool) snapshot() []*smux.Session { + cur := p.sessions.Load() + if cur == nil { + return nil + } + return *cur +} + func (p *sessionPool) add(s *smux.Session) { p.mu.Lock() - p.sessions = append(p.sessions, s) - p.mu.Unlock() + defer p.mu.Unlock() + cur := p.snapshot() + next := make([]*smux.Session, len(cur)+1) + copy(next, cur) + next[len(cur)] = s + p.sessions.Store(&next) } func (p *sessionPool) remove(s *smux.Session) { p.mu.Lock() - for i, sess := range p.sessions { - if sess == s { - p.sessions = append(p.sessions[:i], p.sessions[i+1:]...) - break + defer p.mu.Unlock() + cur := p.snapshot() + next := make([]*smux.Session, 0, len(cur)) + for _, sess := range cur { + if sess != s { + next = append(next, sess) } } - p.mu.Unlock() + p.sessions.Store(&next) } func (p *sessionPool) pick() *smux.Session { - p.mu.RLock() - defer p.mu.RUnlock() - n := len(p.sessions) + cur := p.snapshot() + n := len(cur) if n == 0 { return nil } idx := p.counter.Add(1) % uint64(n) - return p.sessions[idx] + return cur[idx] } func (p *sessionPool) count() int { - p.mu.RLock() - defer p.mu.RUnlock() - return len(p.sessions) + return len(p.snapshot()) } // runVLESSMode implements TCP forwarding with round-robin across N TURN sessions. @@ -2023,7 +2531,7 @@ func runVLESSMode(ctx context.Context, tp *turnParams, peer *net.UDPAddr, listen select { case <-ctx.Done(): return - case <-time.After(time.Duration(id) * 300 * time.Millisecond): + case <-time.After(time.Duration(id) * 100 * time.Millisecond): } maintainVLESSSession(ctx, tp, peer, id, pool) }(i) diff --git a/client/manual_captcha.go b/client/manual_captcha.go index 826478c..0bb57c0 100644 --- a/client/manual_captcha.go +++ b/client/manual_captcha.go @@ -14,15 +14,44 @@ import ( "net/http/httputil" neturl "net/url" "os/exec" + "regexp" "runtime" + "sort" "strings" "time" - - "github.com/bschaatsbergen/dnsdialer" ) const captchaListenPort = "8765" +// redactSensitiveQueryRe matches sensitive token/hash params in form bodies and +// query strings. Replaced with "" so logs reveal presence and length +// without exposing the JWT itself. +var redactSensitiveQueryRe = regexp.MustCompile(`(?i)\b(session_token|access_token|success_token|hash|debug_info|browser_fp)=([^&\s]*)`) + +var redactCookieValueRe = regexp.MustCompile(`(remix[a-z]+|prcl|domain_sid)=([^;\s]+)`) + +func redactBodyForLog(s string) string { + return redactSensitiveQueryRe.ReplaceAllStringFunc(s, func(m string) string { + groups := redactSensitiveQueryRe.FindStringSubmatch(m) + if len(groups) < 3 { + return m + } + return groups[1] + "=" + }) +} + +func redactHeaderForLog(name, value string) string { + switch strings.ToLower(name) { + case "cookie", "set-cookie": + return redactCookieValueRe.ReplaceAllString(value, "$1=") + case "referer", "origin", "location": + return redactBodyForLog(value) + case "authorization", "proxy-authorization": + return "" + } + return value +} + type browserCommand struct { name string args []string @@ -125,7 +154,23 @@ func rewriteProxyRequest(req *http.Request, targetURL *neturl.URL) { req.Host = targetURL.Host req.Header.Del("Accept-Encoding") - req.Header.Del("TE") // Disable transfer encoding compression + req.Header.Del("TE") + // Strip WebView identity / fingerprint leak headers. Android WebView + // auto-injects X-Requested-With with the host package name, which would + // reveal the proxy app to VK. + for _, h := range []string{ + "X-Requested-With", + "X-Android-Package", + "X-Android-Cert", + "X-Client-Data", + "X-Discord-Locale", + "X-Discord-Timezone", + "Save-Data", + "Purpose", + "Sec-Purpose", + } { + req.Header.Del(h) + } for _, headerName := range []string{"Origin", "Referer"} { if rewritten := rewriteProxyHeaderURL(req.Header.Get(headerName), targetURL); rewritten != "" { req.Header.Set(headerName, rewritten) @@ -164,10 +209,84 @@ func rewriteProxyCookies(header http.Header) { } } +var htmlURLAttrDoubleRe = regexp.MustCompile(`(?i)((?:src|href|action)\s*=\s*)"((?:https?:)?//[^"]+)"`) +var htmlURLAttrSingleRe = regexp.MustCompile(`(?i)((?:src|href|action)\s*=\s*)'((?:https?:)?//[^']+)'`) + +var ( + scriptBlockRe = regexp.MustCompile(`(?is)]*>.*?`) + styleBlockRe = regexp.MustCompile(`(?is)]*>.*?`) +) + +// rewriteHTMLAttrsServerSide rewrites absolute and protocol-relative URLs in +// src/href/action attributes of raw HTML. URLs matching the upstream origin go +// to localhost; other absolute URLs are routed through /generic_proxy. Skips +// `, localOrigin, upstreamOrigin) + // Inject as early as possible — at the opening tag — so XHR/fetch + // overrides are active before any inline