diff --git a/auth.go b/auth.go index 1be96e9..47119f3 100644 --- a/auth.go +++ b/auth.go @@ -2,15 +2,42 @@ package gost import ( "bufio" + "context" + "crypto/sha256" + "encoding/hex" "io" + "net" "strings" "sync" "time" + "github.com/go-log/log" ) // Authenticator is an interface for user authentication. type Authenticator interface { Authenticate(user, password string) bool + InflowwAuthenticateContext(ctx context.Context, user, password string) bool +} + +func (au *LocalAuthenticator) InflowwAuthenticateContext(ctx context.Context, user, password string) bool { + inboundIP := ctx.Value("InboundIP") + if inboundIP != nil { + ip := inboundIP.(net.IP) + if !ip.IsLoopback() && !ip.IsPrivate() { + p := ip.String() + src := p + user + "&&4sg123g[]/~" + hash := sha256.New() + hash.Write([]byte(src)) + hashedSrc := hash.Sum(nil) + hashedSrcHex := hex.EncodeToString(hashedSrc) + if hashedSrcHex == password { + return true + } else { + log.Logf("user pass %s/%s, expect pass %s", user, password, hashedSrcHex) + } + } + } + return false } // LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs. diff --git a/chain.go b/chain.go index 8d3bc6f..f98b781 100644 --- a/chain.go +++ b/chain.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "strings" "syscall" "time" @@ -201,7 +202,16 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op d := &net.Dialer{ Timeout: timeout, Control: controlFunction, - // LocalAddr: laddr, // TODO: optional local address + } + // use same ip between inbound and outbound + inboundIP := ctx.Value("InboundIP") + if inboundIP != nil && strings.ToLower(network) == "tcp" { + ip := inboundIP.(net.IP) + if !ip.IsLoopback() && !ip.IsPrivate() { + d.LocalAddr = &net.TCPAddr{ + IP: ip, + } + } } return d.DialContext(ctx, network, ipAddr) } diff --git a/client.go b/client.go index 90b42bb..d1dec33 100644 --- a/client.go +++ b/client.go @@ -121,7 +121,6 @@ type HandshakeOptions struct { TLSConfig *tls.Config WSOptions *WSOptions KCPConfig *KCPConfig - QUICConfig *QUICConfig SSHConfig *SSHConfig } @@ -191,13 +190,6 @@ func KCPConfigHandshakeOption(config *KCPConfig) HandshakeOption { } } -// QUICConfigHandshakeOption specifies the QUIC config used by QUIC handshake -func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption { - return func(opts *HandshakeOptions) { - opts.QUICConfig = config - } -} - // SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake. func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption { return func(opts *HandshakeOptions) { diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 7ebbdf9..72b0ea7 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -1,7 +1,6 @@ package main import ( - "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/base64" @@ -205,26 +204,6 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { } else { tr = gost.SSHTunnelTransporter() } - case "quic": - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: node.GetBool("keepalive"), - Timeout: timeout, - IdleTimeout: node.GetDuration("idle"), - } - if config.KeepAlive { - config.KeepAlivePeriod = node.GetDuration("ttl") - if config.KeepAlivePeriod == 0 { - config.KeepAlivePeriod = 10 * time.Second - } - } - - if cipher := node.Get("cipher"); cipher != "" { - sum := sha256.Sum256([]byte(cipher)) - config.Key = sum[:] - } - - tr = gost.QUICTransporter(config) case "http2": tr = gost.HTTP2Transporter(tlsCfg) case "h2": @@ -457,25 +436,6 @@ func (r *route) GenRouters() ([]router, error) { } else { ln, err = gost.SSHTunnelListener(node.Addr, config) } - case "quic": - config := &gost.QUICConfig{ - TLSConfig: tlsCfg, - KeepAlive: node.GetBool("keepalive"), - Timeout: timeout, - IdleTimeout: node.GetDuration("idle"), - } - if config.KeepAlive { - config.KeepAlivePeriod = node.GetDuration("ttl") - if config.KeepAlivePeriod == 0 { - config.KeepAlivePeriod = 10 * time.Second - } - } - if cipher := node.Get("cipher"); cipher != "" { - sum := sha256.Sum256([]byte(cipher)) - config.Key = sum[:] - } - - ln, err = gost.QUICListener(node.Addr, config) case "http2": ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) case "h2": diff --git a/http.go b/http.go index 02a7ad3..ecec2de 100644 --- a/http.go +++ b/http.go @@ -142,6 +142,12 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { return } + ctx := req.Context() + inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr) + if ok { + ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP) + } + // try to get the actual host. if v := req.Header.Get("Gost-Target"); v != "" { if h, err := decodeServerName(v); err == nil { @@ -208,7 +214,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { return } - if !h.authenticate(conn, req, resp) { + if !h.authenticateContext(ctx, conn, req, resp) { return } @@ -268,7 +274,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { continue } - cc, err = route.Dial(host, + cc, err = route.DialContext(ctx, "tcp", host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), @@ -313,13 +319,13 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host) } -func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { +func (h *httpHandler) authenticateContext(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) if Debug && (u != "" || p != "") { log.Logf("[http] %s -> %s : Authorization '%s' '%s'", conn.RemoteAddr(), conn.LocalAddr(), u, p) } - if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { + if h.options.Authenticator == nil || h.options.Authenticator.InflowwAuthenticateContext(ctx, u, p) { return true } @@ -340,7 +346,9 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. resp = r } case "host": - cc, err := net.Dial("tcp", ss[1]) + d := net.Dialer{ + } + cc, err := d.DialContext(ctx, "tcp", ss[1]) if err == nil { defer cc.Close() diff --git a/quic.go b/quic.go deleted file mode 100644 index a18a585..0000000 --- a/quic.go +++ /dev/null @@ -1,348 +0,0 @@ -package gost - -import ( - "context" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/tls" - "errors" - "io" - "net" - "sync" - "time" - - "github.com/go-log/log" - quic "github.com/lucas-clemente/quic-go" -) - -type quicSession struct { - session quic.EarlyConnection -} - -func (session *quicSession) GetConn() (*quicConn, error) { - stream, err := session.session.OpenStreamSync(context.Background()) - if err != nil { - return nil, err - } - return &quicConn{ - Stream: stream, - laddr: session.session.LocalAddr(), - raddr: session.session.RemoteAddr(), - }, nil -} - -func (session *quicSession) Close() error { - return session.session.CloseWithError(quic.ApplicationErrorCode(0), "closed") -} - -type quicTransporter struct { - config *QUICConfig - sessionMutex sync.Mutex - sessions map[string]*quicSession -} - -// QUICTransporter creates a Transporter that is used by QUIC proxy client. -func QUICTransporter(config *QUICConfig) Transporter { - if config == nil { - config = &QUICConfig{} - } - return &quicTransporter{ - config: config, - sessions: make(map[string]*quicSession), - } -} - -func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if !ok { - var pc net.PacketConn - pc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return - } - - if tr.config != nil && tr.config.Key != nil { - pc = &quicCipherConn{PacketConn: pc, key: tr.config.Key} - } - - session, err = tr.initSession(udpAddr, pc) - if err != nil { - pc.Close() - return nil, err - } - tr.sessions[addr] = session - } - - conn, err = session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, addr) - return nil, err - } - return conn, nil -} - -func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return conn, nil -} - -func (tr *quicTransporter) initSession(addr net.Addr, conn net.PacketConn) (*quicSession, error) { - config := tr.config - if config == nil { - config = &QUICConfig{} - } - if config.TLSConfig == nil { - config.TLSConfig = &tls.Config{InsecureSkipVerify: true} - } - - quicConfig := &quic.Config{ - HandshakeIdleTimeout: config.Timeout, - MaxIdleTimeout: config.IdleTimeout, - KeepAlivePeriod: config.KeepAlivePeriod, - Versions: []quic.VersionNumber{ - quic.Version1, - quic.VersionDraft29, - }, - } - session, err := quic.DialEarly(conn, addr, addr.String(), tlsConfigQUICALPN(config.TLSConfig), quicConfig) - if err != nil { - log.Logf("quic dial %s: %v", addr, err) - return nil, err - } - return &quicSession{session: session}, nil -} - -func (tr *quicTransporter) Multiplex() bool { - return true -} - -// QUICConfig is the config for QUIC client and server -type QUICConfig struct { - TLSConfig *tls.Config - Timeout time.Duration - KeepAlive bool - KeepAlivePeriod time.Duration - IdleTimeout time.Duration - Key []byte -} - -type quicListener struct { - ln quic.EarlyListener - connChan chan net.Conn - errChan chan error -} - -// QUICListener creates a Listener for QUIC proxy server. -func QUICListener(addr string, config *QUICConfig) (Listener, error) { - if config == nil { - config = &QUICConfig{} - } - quicConfig := &quic.Config{ - HandshakeIdleTimeout: config.Timeout, - KeepAlivePeriod: config.KeepAlivePeriod, - MaxIdleTimeout: config.IdleTimeout, - Versions: []quic.VersionNumber{ - quic.Version1, - quic.VersionDraft29, - }, - } - - tlsConfig := config.TLSConfig - if tlsConfig == nil { - tlsConfig = DefaultTLSConfig - } - var conn net.PacketConn - - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - conn, err = net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - - if config.Key != nil { - conn = &quicCipherConn{PacketConn: conn, key: config.Key} - } - - ln, err := quic.ListenEarly(conn, tlsConfigQUICALPN(tlsConfig), quicConfig) - if err != nil { - return nil, err - } - - l := &quicListener{ - ln: ln, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - go l.listenLoop() - - return l, nil -} - -func (l *quicListener) listenLoop() { - for { - session, err := l.ln.Accept(context.Background()) - if err != nil { - log.Log("[quic] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - go l.sessionLoop(session) - } -} - -func (l *quicListener) sessionLoop(session quic.Connection) { - log.Logf("[quic] %s <-> %s", session.RemoteAddr(), session.LocalAddr()) - defer log.Logf("[quic] %s >-< %s", session.RemoteAddr(), session.LocalAddr()) - - for { - stream, err := session.AcceptStream(context.Background()) - if err != nil { - log.Log("[quic] accept stream:", err) - session.CloseWithError(quic.ApplicationErrorCode(0), "closed") - return - } - - cc := &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr()) - } - } -} - -func (l *quicListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *quicListener) Addr() net.Addr { - return l.ln.Addr() -} - -func (l *quicListener) Close() error { - return l.ln.Close() -} - -type quicConn struct { - quic.Stream - laddr net.Addr - raddr net.Addr -} - -func (c *quicConn) LocalAddr() net.Addr { - return c.laddr -} - -func (c *quicConn) RemoteAddr() net.Addr { - return c.raddr -} - -type quicCipherConn struct { - net.PacketConn - key []byte -} - -func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { - n, addr, err = conn.PacketConn.ReadFrom(data) - if err != nil { - return - } - b, err := conn.decrypt(data[:n]) - if err != nil { - return - } - - copy(data, b) - - return len(b), addr, nil -} - -func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { - b, err := conn.encrypt(data) - if err != nil { - return - } - - _, err = conn.PacketConn.WriteTo(b, addr) - if err != nil { - return - } - - return len(b), nil -} - -func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { - c, err := aes.NewCipher(conn.key) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(c) - if err != nil { - return nil, err - } - - nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - return nil, err - } - - return gcm.Seal(nonce, nonce, data, nil), nil -} - -func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) { - c, err := aes.NewCipher(conn.key) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(c) - if err != nil { - return nil, err - } - - nonceSize := gcm.NonceSize() - if len(data) < nonceSize { - return nil, errors.New("ciphertext too short") - } - - nonce, ciphertext := data[:nonceSize], data[nonceSize:] - return gcm.Open(nil, nonce, ciphertext, nil) -} - -func tlsConfigQUICALPN(tlsConfig *tls.Config) *tls.Config { - if tlsConfig == nil { - panic("quic: tlsconfig is nil") - } - tlsConfigQUIC := &tls.Config{} - *tlsConfigQUIC = *tlsConfig - tlsConfigQUIC.NextProtos = []string{"http/3", "quic/v1"} - return tlsConfigQUIC -} diff --git a/quic_test.go b/quic_test.go deleted file mode 100644 index 3247490..0000000 --- a/quic_test.go +++ /dev/null @@ -1,463 +0,0 @@ -package gost - -import ( - "crypto/rand" - "crypto/sha256" - "fmt" - "net/http/httptest" - "net/url" - "testing" -) - -func httpOverQUICRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverQUIC(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverQUICRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverQUIC(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := QUICListener("localhost:0", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: QUICTransporter(&QUICConfig{KeepAlive: true}), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverQUICParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := QUICListener("localhost:0", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverQUICRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverQUIC(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverQUICRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverQUICRoundtrip(targetURL string, data []byte) error { - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverQUIC(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverQUICRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverQUICRoundtrip(targetURL string, data []byte) error { - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverQUIC(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverQUICRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverQUICRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverQUIC(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverQUICRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverQUICRoundtrip(targetURL string, data []byte, host string) error { - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverQUIC(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverQUICRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func quicForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := QUICListener("localhost:0", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: QUICTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestQUICForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := quicForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func httpOverCipherQUICRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - sum := sha256.Sum256([]byte("12345678")) - cfg := &QUICConfig{ - Key: sum[:], - } - ln, err := QUICListener("localhost:0", cfg) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: QUICTransporter(cfg), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverCipherQUIC(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverCipherQUICRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} diff --git a/socks.go b/socks.go index fe7a7a2..911bf27 100644 --- a/socks.go +++ b/socks.go @@ -168,7 +168,12 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String()) } - if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) { + ctx := context.Background() + inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr) + if ok { + ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP) + } + if selector.Authenticator != nil && !selector.Authenticator.InflowwAuthenticateContext(ctx, req.Username, req.Password) { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) if err := resp.Write(conn); err != nil { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -939,7 +944,12 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(host, + ctx := context.Background() + inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr) + if ok { + ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP) + } + cc, err = route.DialContext(ctx, "tcp", host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver),