|
|
|
@ -3,8 +3,15 @@ package gost |
|
|
|
import ( |
|
|
|
"crypto/tls" |
|
|
|
"crypto/x509" |
|
|
|
"errors" |
|
|
|
"net" |
|
|
|
"sync" |
|
|
|
"sync/atomic" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/go-log/log" |
|
|
|
|
|
|
|
smux "gopkg.in/xtaci/smux.v1" |
|
|
|
) |
|
|
|
|
|
|
|
type tlsTransporter struct { |
|
|
|
@ -27,6 +34,113 @@ func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( |
|
|
|
return wrapTLSClient(conn, opts.TLSConfig) |
|
|
|
} |
|
|
|
|
|
|
|
type mtlsTransporter struct { |
|
|
|
tcpTransporter |
|
|
|
sessions map[string]*muxSession |
|
|
|
sessionMutex sync.Mutex |
|
|
|
} |
|
|
|
|
|
|
|
// MTLSTransporter creates a Transporter that is used by multiplex-TLS proxy client.
|
|
|
|
func MTLSTransporter() Transporter { |
|
|
|
return &mtlsTransporter{ |
|
|
|
sessions: make(map[string]*muxSession), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { |
|
|
|
opts := &DialOptions{} |
|
|
|
for _, option := range options { |
|
|
|
option(opts) |
|
|
|
} |
|
|
|
|
|
|
|
if len(opts.IPs) > 0 { |
|
|
|
count := atomic.AddUint64(&tr.count, 1) |
|
|
|
_, sport, err := net.SplitHostPort(addr) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
n := uint64(len(opts.IPs)) |
|
|
|
addr = opts.IPs[int(count%n)] + ":" + sport |
|
|
|
} |
|
|
|
|
|
|
|
tr.sessionMutex.Lock() |
|
|
|
defer tr.sessionMutex.Unlock() |
|
|
|
|
|
|
|
session, ok := tr.sessions[addr] // TODO: the addr may be changed.
|
|
|
|
if !ok { |
|
|
|
if opts.Chain == nil { |
|
|
|
conn, err = net.DialTimeout("tcp", addr, opts.Timeout) |
|
|
|
} else { |
|
|
|
conn, err = opts.Chain.Dial(addr) |
|
|
|
} |
|
|
|
if err != nil { |
|
|
|
return |
|
|
|
} |
|
|
|
session = &muxSession{conn: conn} |
|
|
|
tr.sessions[addr] = session |
|
|
|
} |
|
|
|
return session.conn, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (tr *mtlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { |
|
|
|
opts := &HandshakeOptions{} |
|
|
|
for _, option := range options { |
|
|
|
option(opts) |
|
|
|
} |
|
|
|
|
|
|
|
tr.sessionMutex.Lock() |
|
|
|
defer tr.sessionMutex.Unlock() |
|
|
|
|
|
|
|
session, ok := tr.sessions[opts.Addr] |
|
|
|
if session != nil && session.conn != conn { |
|
|
|
conn.Close() |
|
|
|
return nil, errors.New("mtls: unrecognized connection") |
|
|
|
} |
|
|
|
if !ok || session.session == nil { |
|
|
|
s, err := tr.initSession(opts.Addr, conn, opts) |
|
|
|
if err != nil { |
|
|
|
conn.Close() |
|
|
|
delete(tr.sessions, opts.Addr) |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
session = s |
|
|
|
tr.sessions[opts.Addr] = session |
|
|
|
} |
|
|
|
cc, err := session.GetConn() |
|
|
|
if err != nil { |
|
|
|
session.Close() |
|
|
|
delete(tr.sessions, opts.Addr) |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
return cc, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (tr *mtlsTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { |
|
|
|
if opts == nil { |
|
|
|
opts = &HandshakeOptions{} |
|
|
|
} |
|
|
|
if opts.TLSConfig == nil { |
|
|
|
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} |
|
|
|
} |
|
|
|
conn, err := wrapTLSClient(conn, opts.TLSConfig) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
// stream multiplex
|
|
|
|
smuxConfig := smux.DefaultConfig() |
|
|
|
session, err := smux.Client(conn, smuxConfig) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
return &muxSession{conn: conn, session: session}, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (tr *mtlsTransporter) Multiplex() bool { |
|
|
|
return true |
|
|
|
} |
|
|
|
|
|
|
|
type tlsListener struct { |
|
|
|
net.Listener |
|
|
|
} |
|
|
|
@ -43,6 +157,94 @@ func TLSListener(addr string, config *tls.Config) (Listener, error) { |
|
|
|
return &tlsListener{ln}, nil |
|
|
|
} |
|
|
|
|
|
|
|
type mtlsListener struct { |
|
|
|
ln net.Listener |
|
|
|
connChan chan net.Conn |
|
|
|
errChan chan error |
|
|
|
} |
|
|
|
|
|
|
|
// MTLSListener creates a Listener for multiplex-TLS proxy server.
|
|
|
|
func MTLSListener(addr string, config *tls.Config) (Listener, error) { |
|
|
|
if config == nil { |
|
|
|
config = DefaultTLSConfig |
|
|
|
} |
|
|
|
ln, err := tls.Listen("tcp", addr, config) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
l := &mtlsListener{ |
|
|
|
ln: ln, |
|
|
|
connChan: make(chan net.Conn, 1024), |
|
|
|
errChan: make(chan error, 1), |
|
|
|
} |
|
|
|
go l.listenLoop() |
|
|
|
|
|
|
|
return l, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (l *mtlsListener) listenLoop() { |
|
|
|
for { |
|
|
|
conn, err := l.ln.Accept() |
|
|
|
if err != nil { |
|
|
|
log.Log("[mtls] accept:", err) |
|
|
|
l.errChan <- err |
|
|
|
close(l.errChan) |
|
|
|
return |
|
|
|
} |
|
|
|
go l.mux(conn) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (l *mtlsListener) mux(conn net.Conn) { |
|
|
|
log.Logf("[mtls] %s - %s", conn.RemoteAddr(), l.Addr()) |
|
|
|
smuxConfig := smux.DefaultConfig() |
|
|
|
mux, err := smux.Server(conn, smuxConfig) |
|
|
|
if err != nil { |
|
|
|
log.Logf("[mtls] %s - %s : %s", conn.RemoteAddr(), l.Addr(), err) |
|
|
|
return |
|
|
|
} |
|
|
|
defer mux.Close() |
|
|
|
|
|
|
|
log.Logf("[mtls] %s <-> %s", conn.RemoteAddr(), l.Addr()) |
|
|
|
defer log.Logf("[mtls] %s >-< %s", conn.RemoteAddr(), l.Addr()) |
|
|
|
|
|
|
|
for { |
|
|
|
stream, err := mux.AcceptStream() |
|
|
|
if err != nil { |
|
|
|
log.Log("[mtls] accept stream:", err) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
cc := &muxStreamConn{Conn: conn, stream: stream} |
|
|
|
select { |
|
|
|
case l.connChan <- cc: |
|
|
|
default: |
|
|
|
cc.Close() |
|
|
|
log.Logf("[mtls] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (l *mtlsListener) Accept() (conn net.Conn, err error) { |
|
|
|
var ok bool |
|
|
|
select { |
|
|
|
case conn = <-l.connChan: |
|
|
|
case err, ok = <-l.errChan: |
|
|
|
if !ok { |
|
|
|
err = errors.New("accpet on closed listener") |
|
|
|
} |
|
|
|
} |
|
|
|
return |
|
|
|
} |
|
|
|
func (l *mtlsListener) Addr() net.Addr { |
|
|
|
return l.ln.Addr() |
|
|
|
} |
|
|
|
|
|
|
|
func (l *mtlsListener) Close() error { |
|
|
|
return l.ln.Close() |
|
|
|
} |
|
|
|
|
|
|
|
// Wrap a net.Conn into a client tls connection, performing any
|
|
|
|
// additional verification as needed.
|
|
|
|
//
|
|
|
|
|