mirror of https://github.com/ginuerzh/gost
7 changed files with 642 additions and 106 deletions
@ -0,0 +1,224 @@ |
|||
package gost |
|||
|
|||
import ( |
|||
"bytes" |
|||
"context" |
|||
"errors" |
|||
"net" |
|||
"strconv" |
|||
"time" |
|||
|
|||
"github.com/go-log/log" |
|||
"github.com/miekg/dns" |
|||
) |
|||
|
|||
type dnsHandler struct { |
|||
options *HandlerOptions |
|||
} |
|||
|
|||
// DNSHandler creates a Handler for DNS server.
|
|||
func DNSHandler(raddr string, opts ...HandlerOption) Handler { |
|||
h := &dnsHandler{} |
|||
|
|||
for _, opt := range opts { |
|||
opt(h.options) |
|||
} |
|||
return h |
|||
} |
|||
|
|||
func (h *dnsHandler) Init(opts ...HandlerOption) { |
|||
if h.options == nil { |
|||
h.options = &HandlerOptions{} |
|||
} |
|||
|
|||
for _, opt := range opts { |
|||
opt(h.options) |
|||
} |
|||
} |
|||
|
|||
func (h *dnsHandler) Handle(conn net.Conn) { |
|||
defer conn.Close() |
|||
|
|||
b := mPool.Get().([]byte) |
|||
defer mPool.Put(b) |
|||
|
|||
n, err := conn.Read(b) |
|||
if err != nil { |
|||
log.Logf("[dns] %s - %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) |
|||
} |
|||
|
|||
mq := &dns.Msg{} |
|||
if err = mq.Unpack(b[:n]); err != nil { |
|||
log.Logf("[dns] %s - %s request unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) |
|||
return |
|||
} |
|||
log.Logf("[dns] %s -> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mq)) |
|||
if Debug { |
|||
log.Logf("[dns] %s >>> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mq.String()) |
|||
} |
|||
|
|||
start := time.Now() |
|||
reply, err := h.options.Resolver.Exchange(context.Background(), b[:n]) |
|||
if err != nil { |
|||
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) |
|||
return |
|||
} |
|||
|
|||
rtt := time.Since(start) |
|||
|
|||
mr := &dns.Msg{} |
|||
if err = mr.Unpack(reply); err != nil { |
|||
log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) |
|||
return |
|||
} |
|||
log.Logf("[dns] %s <- %s: %s [%s]", |
|||
conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mr), rtt) |
|||
if Debug { |
|||
log.Logf("[dns] %s <<< %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mr.String()) |
|||
} |
|||
|
|||
if _, err = conn.Write(reply); err != nil { |
|||
log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) |
|||
} |
|||
} |
|||
|
|||
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { |
|||
buf := new(bytes.Buffer) |
|||
buf.WriteString(m.MsgHdr.String() + " ") |
|||
buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ") |
|||
buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ") |
|||
buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ") |
|||
buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra))) |
|||
return buf.String() |
|||
} |
|||
|
|||
type DNSOptions struct { |
|||
TCPMode bool |
|||
UDPSize int |
|||
ReadTimeout time.Duration |
|||
WriteTimeout time.Duration |
|||
} |
|||
|
|||
type dnsListener struct { |
|||
addr net.Addr |
|||
server *dns.Server |
|||
connChan chan net.Conn |
|||
errc chan error |
|||
} |
|||
|
|||
func DNSListener(addr string, options *DNSOptions) (Listener, error) { |
|||
if options == nil { |
|||
options = &DNSOptions{} |
|||
} |
|||
|
|||
ln := &dnsListener{ |
|||
connChan: make(chan net.Conn, 128), |
|||
errc: make(chan error, 1), |
|||
} |
|||
|
|||
var nets string |
|||
var err error |
|||
|
|||
if options.TCPMode { |
|||
nets = "tcp" |
|||
ln.addr, err = net.ResolveTCPAddr("tcp", addr) |
|||
} else { |
|||
nets = "udp" |
|||
ln.addr, err = net.ResolveUDPAddr("udp", addr) |
|||
} |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
ln.server = &dns.Server{ |
|||
Addr: addr, |
|||
Net: nets, |
|||
} |
|||
|
|||
dns.HandleFunc(".", ln.handleRequest) |
|||
|
|||
go func() { |
|||
if err := ln.server.ListenAndServe(); err != nil { |
|||
ln.errc <- err |
|||
return |
|||
} |
|||
}() |
|||
|
|||
select { |
|||
case err := <-ln.errc: |
|||
return nil, err |
|||
default: |
|||
} |
|||
|
|||
return ln, nil |
|||
} |
|||
|
|||
func (l *dnsListener) handleRequest(w dns.ResponseWriter, m *dns.Msg) { |
|||
if w == nil || m == nil { |
|||
return |
|||
} |
|||
|
|||
conn := &dnsServerConn{ |
|||
mq: make(chan []byte, 1), |
|||
ResponseWriter: w, |
|||
} |
|||
|
|||
buf := mPool.Get().([]byte) |
|||
defer mPool.Put(buf) |
|||
buf = buf[:0] |
|||
b, err := m.PackBuffer(buf) |
|||
if err != nil { |
|||
log.Logf("[dns] %s: %v", l.addr, err) |
|||
return |
|||
} |
|||
conn.mq <- b |
|||
|
|||
select { |
|||
case l.connChan <- conn: |
|||
default: |
|||
log.Logf("[dns] %s: connection queue is full", l.addr) |
|||
} |
|||
} |
|||
|
|||
func (l *dnsListener) Accept() (conn net.Conn, err error) { |
|||
select { |
|||
case conn = <-l.connChan: |
|||
case err = <-l.errc: |
|||
} |
|||
return |
|||
} |
|||
|
|||
func (l *dnsListener) Close() error { |
|||
return l.server.Shutdown() |
|||
} |
|||
|
|||
func (l *dnsListener) Addr() net.Addr { |
|||
return l.addr |
|||
} |
|||
|
|||
type dnsServerConn struct { |
|||
mq chan []byte |
|||
dns.ResponseWriter |
|||
} |
|||
|
|||
func (c *dnsServerConn) Read(b []byte) (n int, err error) { |
|||
var mb []byte |
|||
select { |
|||
case mb = <-c.mq: |
|||
default: |
|||
} |
|||
n = copy(b, mb) |
|||
return |
|||
} |
|||
|
|||
func (c *dnsServerConn) SetDeadline(t time.Time) error { |
|||
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} |
|||
} |
|||
|
|||
func (c *dnsServerConn) SetReadDeadline(t time.Time) error { |
|||
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} |
|||
} |
|||
|
|||
func (c *dnsServerConn) SetWriteDeadline(t time.Time) error { |
|||
return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} |
|||
} |
|||
Loading…
Reference in new issue