|
|
|
@ -3,9 +3,15 @@ package gost |
|
|
|
import ( |
|
|
|
"bytes" |
|
|
|
"context" |
|
|
|
"crypto/tls" |
|
|
|
"encoding/base64" |
|
|
|
"errors" |
|
|
|
"io" |
|
|
|
"io/ioutil" |
|
|
|
"net" |
|
|
|
"net/http" |
|
|
|
"strconv" |
|
|
|
"strings" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/go-log/log" |
|
|
|
@ -112,15 +118,16 @@ func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { |
|
|
|
} |
|
|
|
|
|
|
|
type DNSOptions struct { |
|
|
|
TCPMode bool |
|
|
|
Mode string |
|
|
|
UDPSize int |
|
|
|
ReadTimeout time.Duration |
|
|
|
WriteTimeout time.Duration |
|
|
|
TLSConfig *tls.Config |
|
|
|
} |
|
|
|
|
|
|
|
type dnsListener struct { |
|
|
|
addr net.Addr |
|
|
|
server *dns.Server |
|
|
|
server dnsServer |
|
|
|
connChan chan net.Conn |
|
|
|
errc chan error |
|
|
|
} |
|
|
|
@ -130,31 +137,70 @@ func DNSListener(addr string, options *DNSOptions) (Listener, error) { |
|
|
|
options = &DNSOptions{} |
|
|
|
} |
|
|
|
|
|
|
|
tlsConfig := options.TLSConfig |
|
|
|
if tlsConfig == nil { |
|
|
|
tlsConfig = DefaultTLSConfig |
|
|
|
} |
|
|
|
|
|
|
|
ln := &dnsListener{ |
|
|
|
connChan: make(chan net.Conn, 128), |
|
|
|
errc: make(chan error, 1), |
|
|
|
} |
|
|
|
|
|
|
|
var nets string |
|
|
|
var srv dnsServer |
|
|
|
var err error |
|
|
|
switch strings.ToLower(options.Mode) { |
|
|
|
case "tcp": |
|
|
|
srv = &dns.Server{ |
|
|
|
Net: "tcp", |
|
|
|
Addr: addr, |
|
|
|
Handler: ln, |
|
|
|
ReadTimeout: options.ReadTimeout, |
|
|
|
WriteTimeout: options.WriteTimeout, |
|
|
|
} |
|
|
|
case "tls": |
|
|
|
srv = &dns.Server{ |
|
|
|
Net: "tcp-tls", |
|
|
|
Addr: addr, |
|
|
|
Handler: ln, |
|
|
|
TLSConfig: tlsConfig, |
|
|
|
ReadTimeout: options.ReadTimeout, |
|
|
|
WriteTimeout: options.WriteTimeout, |
|
|
|
} |
|
|
|
case "https": |
|
|
|
srv = &dohServer{ |
|
|
|
addr: addr, |
|
|
|
tlsConfig: tlsConfig, |
|
|
|
server: &http.Server{ |
|
|
|
Handler: ln, |
|
|
|
ReadTimeout: options.ReadTimeout, |
|
|
|
WriteTimeout: options.WriteTimeout, |
|
|
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
if options.TCPMode { |
|
|
|
nets = "tcp" |
|
|
|
default: |
|
|
|
ln.addr, err = net.ResolveTCPAddr("tcp", addr) |
|
|
|
} else { |
|
|
|
nets = "udp" |
|
|
|
ln.addr, err = net.ResolveUDPAddr("udp", addr) |
|
|
|
srv = &dns.Server{ |
|
|
|
Net: "udp", |
|
|
|
Addr: addr, |
|
|
|
Handler: ln, |
|
|
|
UDPSize: options.UDPSize, |
|
|
|
ReadTimeout: options.ReadTimeout, |
|
|
|
WriteTimeout: options.WriteTimeout, |
|
|
|
} |
|
|
|
} |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
ln.server = &dns.Server{ |
|
|
|
Addr: addr, |
|
|
|
Net: nets, |
|
|
|
if ln.addr == nil { |
|
|
|
ln.addr, err = net.ResolveTCPAddr("tcp", addr) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
dns.HandleFunc(".", ln.handleRequest) |
|
|
|
ln.server = srv |
|
|
|
|
|
|
|
go func() { |
|
|
|
if err := ln.server.ListenAndServe(); err != nil { |
|
|
|
@ -172,30 +218,76 @@ func DNSListener(addr string, options *DNSOptions) (Listener, error) { |
|
|
|
return ln, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (l *dnsListener) handleRequest(w dns.ResponseWriter, m *dns.Msg) { |
|
|
|
if w == nil || m == nil { |
|
|
|
return |
|
|
|
func (l *dnsListener) serve(w dnsResponseWriter, mq []byte) (err error) { |
|
|
|
conn := newDNSServerConn(l.addr, w.RemoteAddr()) |
|
|
|
conn.mq <- mq |
|
|
|
|
|
|
|
select { |
|
|
|
case l.connChan <- conn: |
|
|
|
default: |
|
|
|
return errors.New("connection queue is full") |
|
|
|
} |
|
|
|
|
|
|
|
conn := &dnsServerConn{ |
|
|
|
mq: make(chan []byte, 1), |
|
|
|
ResponseWriter: w, |
|
|
|
select { |
|
|
|
case mr := <-conn.mr: |
|
|
|
_, err = w.Write(mr) |
|
|
|
case <-conn.cclose: |
|
|
|
err = io.EOF |
|
|
|
} |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
buf := mPool.Get().([]byte) |
|
|
|
defer mPool.Put(buf) |
|
|
|
buf = buf[:0] |
|
|
|
b, err := m.PackBuffer(buf) |
|
|
|
func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { |
|
|
|
b, err := m.Pack() |
|
|
|
if err != nil { |
|
|
|
log.Logf("[dns] %s: %v", l.addr, err) |
|
|
|
return |
|
|
|
} |
|
|
|
conn.mq <- b |
|
|
|
if err := l.serve(w, b); err != nil { |
|
|
|
log.Logf("[dns] %s: %v", l.addr, err) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
select { |
|
|
|
case l.connChan <- conn: |
|
|
|
// Based on https://github.com/semihalev/sdns
|
|
|
|
func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
|
|
|
var buf []byte |
|
|
|
var err error |
|
|
|
switch r.Method { |
|
|
|
case http.MethodGet: |
|
|
|
buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns")) |
|
|
|
if len(buf) == 0 || err != nil { |
|
|
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) |
|
|
|
return |
|
|
|
} |
|
|
|
case http.MethodPost: |
|
|
|
if r.Header.Get("Content-Type") != "application/dns-message" { |
|
|
|
http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
buf, err = ioutil.ReadAll(r.Body) |
|
|
|
if err != nil { |
|
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) |
|
|
|
return |
|
|
|
} |
|
|
|
default: |
|
|
|
log.Logf("[dns] %s: connection queue is full", l.addr) |
|
|
|
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
mq := &dns.Msg{} |
|
|
|
if err := mq.Unpack(buf); err != nil { |
|
|
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
w.Header().Set("Server", "SDNS") |
|
|
|
w.Header().Set("Content-Type", "application/dns-message") |
|
|
|
|
|
|
|
raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) |
|
|
|
if err := l.serve(newDoHResponseWriter(raddr, w), buf); err != nil { |
|
|
|
log.Logf("[dns] %s: %v", l.addr, err) |
|
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@ -215,21 +307,85 @@ func (l *dnsListener) Addr() net.Addr { |
|
|
|
return l.addr |
|
|
|
} |
|
|
|
|
|
|
|
type dnsServer interface { |
|
|
|
ListenAndServe() error |
|
|
|
Shutdown() error |
|
|
|
} |
|
|
|
|
|
|
|
type dohServer struct { |
|
|
|
addr string |
|
|
|
tlsConfig *tls.Config |
|
|
|
server *http.Server |
|
|
|
} |
|
|
|
|
|
|
|
func (s *dohServer) ListenAndServe() error { |
|
|
|
ln, err := net.Listen("tcp", s.addr) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, s.tlsConfig) |
|
|
|
return s.server.Serve(ln) |
|
|
|
} |
|
|
|
|
|
|
|
func (s *dohServer) Shutdown() error { |
|
|
|
return s.server.Shutdown(context.Background()) |
|
|
|
} |
|
|
|
|
|
|
|
type dnsServerConn struct { |
|
|
|
mq chan []byte |
|
|
|
dns.ResponseWriter |
|
|
|
mq chan []byte |
|
|
|
mr chan []byte |
|
|
|
cclose chan struct{} |
|
|
|
laddr, raddr net.Addr |
|
|
|
} |
|
|
|
|
|
|
|
func newDNSServerConn(laddr, raddr net.Addr) *dnsServerConn { |
|
|
|
return &dnsServerConn{ |
|
|
|
mq: make(chan []byte, 1), |
|
|
|
mr: make(chan []byte, 1), |
|
|
|
laddr: laddr, |
|
|
|
raddr: raddr, |
|
|
|
cclose: make(chan struct{}), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (c *dnsServerConn) Read(b []byte) (n int, err error) { |
|
|
|
var mb []byte |
|
|
|
select { |
|
|
|
case mb = <-c.mq: |
|
|
|
default: |
|
|
|
case mb := <-c.mq: |
|
|
|
n = copy(b, mb) |
|
|
|
case <-c.cclose: |
|
|
|
err = errors.New("connection is closed") |
|
|
|
} |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
func (c *dnsServerConn) Write(b []byte) (n int, err error) { |
|
|
|
select { |
|
|
|
case c.mr <- b: |
|
|
|
n = len(b) |
|
|
|
case <-c.cclose: |
|
|
|
err = errors.New("broken pipe") |
|
|
|
} |
|
|
|
n = copy(b, mb) |
|
|
|
|
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
func (c *dnsServerConn) Close() error { |
|
|
|
select { |
|
|
|
case <-c.cclose: |
|
|
|
default: |
|
|
|
close(c.cclose) |
|
|
|
} |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
func (c *dnsServerConn) LocalAddr() net.Addr { |
|
|
|
return c.laddr |
|
|
|
} |
|
|
|
|
|
|
|
func (c *dnsServerConn) RemoteAddr() net.Addr { |
|
|
|
return c.raddr |
|
|
|
} |
|
|
|
|
|
|
|
func (c *dnsServerConn) SetDeadline(t time.Time) error { |
|
|
|
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} |
|
|
|
} |
|
|
|
@ -241,3 +397,24 @@ func (c *dnsServerConn) SetReadDeadline(t time.Time) error { |
|
|
|
func (c *dnsServerConn) SetWriteDeadline(t time.Time) error { |
|
|
|
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} |
|
|
|
} |
|
|
|
|
|
|
|
type dnsResponseWriter interface { |
|
|
|
io.Writer |
|
|
|
RemoteAddr() net.Addr |
|
|
|
} |
|
|
|
|
|
|
|
type dohResponseWriter struct { |
|
|
|
raddr net.Addr |
|
|
|
http.ResponseWriter |
|
|
|
} |
|
|
|
|
|
|
|
func newDoHResponseWriter(raddr net.Addr, w http.ResponseWriter) dnsResponseWriter { |
|
|
|
return &dohResponseWriter{ |
|
|
|
raddr: raddr, |
|
|
|
ResponseWriter: w, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (w *dohResponseWriter) RemoteAddr() net.Addr { |
|
|
|
return w.raddr |
|
|
|
} |
|
|
|
|