diff --git a/auth.go b/auth.go index 4158cb9..4ded573 100644 --- a/auth.go +++ b/auth.go @@ -2,7 +2,6 @@ package gost import ( "bufio" - "context" "crypto/sha256" "encoding/hex" "io" @@ -17,13 +16,11 @@ import ( // Authenticator is an interface for user authentication. type Authenticator interface { Authenticate(user, password string) bool - InflowwAuthenticateContext(ctx context.Context, user, password string) bool + IFAuthenticate(ip net.IP, user, password string) bool } -func (au *LocalAuthenticator) InflowwAuthenticateContext(ctx context.Context, user, password string) bool { - inboundIP := ctx.Value("InboundIP") - if inboundIP != nil { - ip := inboundIP.(net.IP) +func (au *LocalAuthenticator) IFAuthenticate(ip net.IP, user, password string) bool { + if ip != nil { if !ip.IsLoopback() && !ip.IsPrivate() { expected := GeneratePass(ip.String(), user) if expected == password { diff --git a/chain.go b/chain.go index 8d3bc6f..d001980 100644 --- a/chain.go +++ b/chain.go @@ -9,11 +9,13 @@ import ( "time" "github.com/go-log/log" + "golang.org/x/crypto/ssh" ) var ( // ErrEmptyChain is an error that implies the chain is empty. ErrEmptyChain = errors.New("empty chain") + localAddrKey = "localAddr" ) // Chain is a proxy chain that holds a list of proxy node groups. @@ -109,14 +111,29 @@ func (c *Chain) IsEmpty() bool { return c == nil || len(c.nodeGroups) == 0 } +func GetIP(conn net.Conn) (ip net.IP) { + IP := conn.LocalAddr().(*net.TCPAddr).IP + if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() { + return IP + } + return nil +} +func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) { + IP := conn.LocalAddr().(*net.TCPAddr).IP + if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() { + return IP + } + return nil +} + // Dial connects to the target TCP address addr through the chain. // Deprecated: use DialContext instead. -func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) { - return c.DialContext(context.Background(), "tcp", address, opts...) +func (c *Chain) Dial(ip net.IP, address string, opts ...ChainOption) (conn net.Conn, err error) { + return c.DialContext(ip, context.Background(), "tcp", address, opts...) } // DialContext connects to the address on the named network using the provided context. -func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) { +func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) { options := &ChainOptions{} for _, opt := range opts { opt(options) @@ -131,7 +148,7 @@ func (c *Chain) DialContext(ctx context.Context, network, address string, opts . } for i := 0; i < retries; i++ { - conn, err = c.dialWithOptions(ctx, network, address, options) + conn, err = c.dialWithOptions(ip, ctx, network, address, options) if err == nil { break } @@ -139,7 +156,7 @@ func (c *Chain) DialContext(ctx context.Context, network, address string, opts . return } -func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { +func (c *Chain) dialWithOptions(ip net.IP, ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { if options == nil { options = &ChainOptions{} } @@ -198,10 +215,28 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op } default: } + var localAddr net.Addr + // ip := cn.LocalAddr().(*net.TCPAddr).IP + // ctx := context.WithValue(context.Background(), localAddrKey, ip) + // if v := ctx.Value(localAddrKey); v != nil { + // if ip, ok := v.(net.IP); ok { + // localAddr = &net.TCPAddr{ + // IP: ip, + // Port: 0, + // } + // } + // } + if ip != nil { + localAddr = &net.TCPAddr{ + IP: ip, + Port: 0, + } + } + d := &net.Dialer{ - Timeout: timeout, - Control: controlFunction, - // LocalAddr: laddr, // TODO: optional local address + Timeout: timeout, + Control: controlFunction, + LocalAddr: localAddr, } return d.DialContext(ctx, network, ipAddr) } diff --git a/cmd/gost/build.sh b/cmd/gost/build.sh index b0b07b9..1fc90fe 100644 --- a/cmd/gost/build.sh +++ b/cmd/gost/build.sh @@ -2,6 +2,6 @@ GOOS=linux GOARCH=amd64 go build -o gost -rsync -avz gost root@192.168.2.186:/root/gost/gost -rsync -avz gost root@192.168.3.35:/root/gost/gost -rsync -avz gost root@192.168.3.40:/root/gost/gost \ No newline at end of file +# rsync -avz gost root@192.168.2.186:/root/gost/gost +# rsync -avz gost root@192.168.3.35:/root/gost/gost +# rsync -avz gost root@192.168.3.40:/root/gost/gost \ No newline at end of file diff --git a/dns.go b/dns.go index cc67718..b87bdc4 100644 --- a/dns.go +++ b/dns.go @@ -82,7 +82,8 @@ func (h *dnsHandler) Handle(conn net.Conn) { if resolver == nil { resolver = defaultResolver } - reply, err := resolver.Exchange(context.Background(), b[:n]) + ip := GetIP(conn) + reply, err := resolver.Exchange(ip, context.Background(), b[:n]) if err != nil { log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) return diff --git a/examples/ssu/ssu.go b/examples/ssu/ssu.go index 6aeee1c..eae71fb 100644 --- a/examples/ssu/ssu.go +++ b/examples/ssu/ssu.go @@ -28,13 +28,13 @@ func ssuClient() { if err != nil { log.Fatal(err) } - cc := ss.NewSecurePacketConn(conn, cp, false) + cc := ss.NewSecurePacketConn(conn, cp) raddr, _ := net.ResolveUDPAddr("udp", ":8080") msg := []byte(`abcdefghijklmnopqrstuvwxyz`) dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), msg) buf := bytes.Buffer{} - dgram.Write(&buf) + dgram.WriteTo(&buf) for { log.Printf("%# x", buf.Bytes()[3:]) if _, err := cc.WriteTo(buf.Bytes()[3:], addr); err != nil { diff --git a/forward.go b/forward.go index 9985780..e2913e8 100644 --- a/forward.go +++ b/forward.go @@ -119,6 +119,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { var cc net.Conn var node Node var err error + ip := GetIP(conn) for i := 0; i < retries; i++ { if len(h.group.Nodes()) > 0 { node, err = h.group.Next() @@ -128,7 +129,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { } } - cc, err = h.options.Chain.Dial(node.Addr, + cc, err = h.options.Chain.Dial(ip, node.Addr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), ResolverChainOption(h.options.Resolver), @@ -197,13 +198,8 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { return } } - - cc, err := h.options.Chain.DialContext( - context.Background(), - "udp", - node.Addr, - ResolverChainOption(h.options.Resolver), - ) + ip := GetIP(conn) + cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", node.Addr, ResolverChainOption(h.options.Resolver)) if err != nil { node.MarkDead() log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -452,7 +448,7 @@ func (l *tcpRemoteForwardListener) Accept() (conn net.Conn, err error) { func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { lastNode := l.chain.LastNode() if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { - return l.chain.Dial(l.addr.String()) + return l.chain.Dial(nil, l.addr.String()) } if l.isChainValid() { diff --git a/http.go b/http.go index 8f9e3fd..ed5fafd 100644 --- a/http.go +++ b/http.go @@ -268,8 +268,8 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) continue } - - cc, err = route.Dial(host, + ip := GetIP(conn) + cc, err = route.Dial(ip, host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), @@ -358,7 +358,8 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. log.Logf("[http] %s -> %s : Authorization '%s' '%s'", conn.RemoteAddr(), conn.LocalAddr(), u, p) } - if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { + ip := GetIP(conn) + if h.options.Authenticator == nil || h.options.Authenticator.IFAuthenticate(ip, u, p) { return true } diff --git a/http2.go b/http2.go index de152ea..278644d 100644 --- a/http2.go +++ b/http2.go @@ -142,7 +142,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, if !ok { // NOTE: There is no real connection to the HTTP2 server at this moment. // So we try to connect to the server to check the server health. - conn, err := opts.Chain.Dial(addr) + conn, err := opts.Chain.Dial(nil, addr) if err != nil { log.Log("http2 dial:", addr, err) return nil, err @@ -156,7 +156,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { - conn, err := opts.Chain.Dial(adr) + conn, err := opts.Chain.Dial(nil, adr) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { - conn, err := opts.Chain.Dial(addr) + conn, err := opts.Chain.Dial(nil, addr) if err != nil { return nil, err } @@ -338,10 +338,11 @@ func (h *http2Handler) Handle(conn net.Conn) { return } - h.roundTrip(h2c.w, h2c.r) + ip := GetIP(conn) + h.roundTrip(ip, h2c.w, h2c.r) } -func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { +func (h *http2Handler) roundTrip(ip net.IP, w http.ResponseWriter, r *http.Request) { host := r.Header.Get("Gost-Target") if host == "" { host = r.Host @@ -391,7 +392,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { Body: io.NopCloser(bytes.NewReader([]byte{})), } - if !h.authenticate(w, r, resp) { + if !h.authenticate(ip, w, r, resp) { return } @@ -427,7 +428,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(host, + cc, err = route.Dial(ip, host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), @@ -481,13 +482,13 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) } -func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) { +func (h *http2Handler) authenticate(ip net.IP, w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) { laddr := h.options.Addr u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) if Debug && (u != "" || p != "") { log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p) } - if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { + if h.options.Authenticator == nil || h.options.Authenticator.IFAuthenticate(ip, u, p) { return true } diff --git a/redirect.go b/redirect.go index 199d5c2..d48d66c 100644 --- a/redirect.go +++ b/redirect.go @@ -53,8 +53,8 @@ func (h *tcpRedirectHandler) Handle(c net.Conn) { defer conn.Close() log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) - - cc, err := h.options.Chain.DialContext(context.Background(), + ip := GetIP(c) + cc, err := h.options.Chain.DialContext(ip, context.Background(), "tcp", dstAddr.String(), RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), @@ -134,8 +134,8 @@ func (h *udpRedirectHandler) Handle(conn net.Conn) { log.Log("[red-udp] wrong connection type") return } - - cc, err := h.options.Chain.DialContext(context.Background(), + ip := GetIP(conn) + cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String(), RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), diff --git a/relay.go b/relay.go index 7df8137..0dd36c3 100644 --- a/relay.go +++ b/relay.go @@ -165,7 +165,8 @@ func (h *relayHandler) Handle(conn net.Conn) { Version: relay.Version1, Status: relay.StatusOK, } - if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) { + ip := GetIP(conn) + if h.options.Authenticator != nil && !h.options.Authenticator.IFAuthenticate(ip, user, pass) { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) @@ -228,7 +229,8 @@ func (h *relayHandler) Handle(conn net.Conn) { } log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr) - cc, err = h.options.Chain.DialContext(ctx, + + cc, err = h.options.Chain.DialContext(ip, ctx, network, raddr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), diff --git a/resolver.go b/resolver.go index 0f07ceb..5a708f0 100644 --- a/resolver.go +++ b/resolver.go @@ -173,7 +173,7 @@ type Resolver interface { Resolve(host string) ([]net.IP, error) // Exchange performs a synchronous query, // It sends the message query and waits for a reply. - Exchange(ctx context.Context, query []byte) (reply []byte, err error) + Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error) } // ReloadResolver is resolover that support live reloading. @@ -282,7 +282,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { ctx := context.Background() for _, ns := range r.copyServers() { - ips, err = r.resolve(ctx, ns.exchanger, host) + ips, err = r.resolve(ip, ctx, ns.exchanger, host) if err != nil { log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) continue @@ -299,7 +299,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } -func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { if ex == nil { return } @@ -309,36 +309,36 @@ func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips r.mux.RUnlock() if prefer == "ipv6" { // prefer ipv6 - if ips, err = r.resolve6(ctx, ex, host); len(ips) > 0 { + if ips, err = r.resolve6(ip, ctx, ex, host); len(ips) > 0 { return } - return r.resolve4(ctx, ex, host) + return r.resolve4(ip, ctx, ex, host) } - if ips, err = r.resolve4(ctx, ex, host); len(ips) > 0 { + if ips, err = r.resolve4(ip, ctx, ex, host); len(ips) > 0 { return } - return r.resolve6(ctx, ex, host) + return r.resolve6(ip, ctx, ex, host) } -func (r *resolver) resolve4(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve4(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { mq := dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeA) - return r.resolveIPs(ctx, ex, &mq) + return r.resolveIPs(ip, ctx, ex, &mq) } -func (r *resolver) resolve6(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve6(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { mq := dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) - return r.resolveIPs(ctx, ex, &mq) + return r.resolveIPs(ip, ctx, ex, &mq) } -func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { +func (r *resolver) resolveIPs(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { key := newResolverCacheKey(&mq.Question[0]) mr := r.cache.loadCache(key) if mr == nil { r.addSubnetOpt(mq) - mr, err = r.exchangeMsg(ctx, ex, mq) + mr, err = r.exchangeMsg(ip, ctx, ex, mq) if err != nil { return } @@ -379,7 +379,7 @@ func (r *resolver) addSubnetOpt(m *dns.Msg) { m.Extra = append(m.Extra, opt) } -func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { +func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error) { mq := &dns.Msg{} if err = mq.Unpack(query); err != nil { return @@ -411,7 +411,7 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er for _, ns := range r.copyServers() { log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String()) - mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) + mr, err = r.exchangeMsg(ip, ctx, ns.exchanger, mq) if err == nil { break } @@ -423,12 +423,12 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er return mr.Pack() } -func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { +func (r *resolver) exchangeMsg(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { query, err := mq.Pack() if err != nil { return } - reply, err := ex.Exchange(ctx, query) + reply, err := ex.Exchange(ip, ctx, query) if err != nil { return } @@ -657,7 +657,7 @@ func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time. // Exchanger is an interface for DNS synchronous query. type Exchanger interface { - Exchange(ctx context.Context, query []byte) ([]byte, error) + Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) } type exchangerOptions struct { @@ -704,9 +704,9 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { } } -func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { +func (ex *dnsExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.options.chain.DialContext(ctx, + c, err := ex.options.chain.DialContext(ip, ctx, "udp", ex.addr, TimeoutChainOption(ex.options.timeout), ) @@ -754,9 +754,9 @@ func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { } } -func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { +func (ex *dnsTCPExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.options.chain.DialContext(ctx, + c, err := ex.options.chain.DialContext(ip, ctx, "tcp", ex.addr, TimeoutChainOption(ex.options.timeout), ) @@ -782,6 +782,7 @@ func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, } type dotExchanger struct { + ip net.IP addr string tlsConfig *tls.Config options exchangerOptions @@ -810,8 +811,8 @@ func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption } } -func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - conn, err = ex.options.chain.DialContext(ctx, +func (ex *dotExchanger) dial(ip net.IP, ctx context.Context, network, address string) (conn net.Conn, err error) { + conn, err = ex.options.chain.DialContext(ip, ctx, network, address, TimeoutChainOption(ex.options.timeout), ) @@ -823,9 +824,9 @@ func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn return } -func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { +func (ex *dotExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.dial(ctx, "tcp", ex.addr) + c, err := ex.dial(ip, ctx, "tcp", ex.addr) if err != nil { return nil, err } @@ -881,8 +882,9 @@ func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOp return ex } -func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - return ex.options.chain.DialContext(ctx, +func (ex *dohExchanger) dialContext(ip net.IP, ctx context.Context, network, address string) (net.Conn, error) { + // todo:: + return ex.options.chain.DialContext(ip, ctx, network, address, TimeoutChainOption(ex.options.timeout), ) diff --git a/sni.go b/sni.go index 7d4c268..2b532c0 100644 --- a/sni.go +++ b/sni.go @@ -134,6 +134,7 @@ func (h *sniHandler) Handle(conn net.Conn) { var cc net.Conn var route *Chain + ip := GetIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -151,7 +152,7 @@ func (h *sniHandler) Handle(conn net.Conn) { fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(host, + cc, err = route.Dial(ip, host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), diff --git a/socks.go b/socks.go index 69e4cd0..a7a1583 100644 --- a/socks.go +++ b/socks.go @@ -168,8 +168,8 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (string, if Debug { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String()) } - - if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) { + ip := GetIP(conn) + if selector.Authenticator != nil && !selector.Authenticator.IFAuthenticate(ip, req.Username, req.Password) { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) if err := resp.Write(conn); err != nil { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -431,7 +431,7 @@ func (tr *socks5MuxBindTransporter) Dial(addr string, options ...DialOption) (co if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(addr) + conn, err = opts.Chain.Dial(nil, addr) } if err != nil { return @@ -923,6 +923,7 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { var err error var cc net.Conn var route *Chain + ip := GetIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -940,7 +941,7 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(host, + cc, err = route.Dial(ip, host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), @@ -1773,6 +1774,7 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { var err error var cc net.Conn var route *Chain + ip := GetIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(addr) if err != nil { @@ -1790,7 +1792,7 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { fmt.Fprintf(&buf, "%s", addr) log.Log("[route]", buf.String()) - cc, err = route.Dial(addr, + cc, err = route.Dial(ip, addr, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), diff --git a/ss.go b/ss.go index ac45563..4dbf7c4 100644 --- a/ss.go +++ b/ss.go @@ -164,6 +164,7 @@ func (h *shadowHandler) Handle(conn net.Conn) { var cc net.Conn var route *Chain + ip := GetIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -180,8 +181,7 @@ func (h *shadowHandler) Handle(conn net.Conn) { } fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - - cc, err = route.Dial(host, + cc, err = route.Dial(ip, host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), @@ -297,7 +297,8 @@ func (h *shadowUDPHandler) Handle(conn net.Conn) { defer conn.Close() var cc net.PacketConn - c, err := h.options.Chain.DialContext(context.Background(), "udp", "") + ip := GetIP(conn) + c, err := h.options.Chain.DialContext(ip, context.Background(), "udp", "") if err != nil { log.Logf("[ssu] %s: %s", conn.LocalAddr(), err) return diff --git a/ssh.go b/ssh.go index a78e4c3..8d17bb2 100644 --- a/ssh.go +++ b/ssh.go @@ -194,7 +194,7 @@ func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(addr) + conn, err = opts.Chain.Dial(nil, addr) } if err != nil { return @@ -308,7 +308,7 @@ func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn n if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(addr) + conn, err = opts.Chain.Dial(nil, addr) } if err != nil { return @@ -583,9 +583,9 @@ func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewCha if p.Host1 == "" { p.Host1 = "" } - + ip := GetSshIP(conn) go ssh.DiscardRequests(requests) - go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) + go h.directPortForwardChannel(ip, channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) default: log.Log("[ssh] Unknown channel type:", t) newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) @@ -596,7 +596,7 @@ func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewCha conn.Wait() } -func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) { +func (h *sshForwardHandler) directPortForwardChannel(ip net.IP, channel ssh.Channel, raddr string) { defer channel.Close() log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr) @@ -611,7 +611,7 @@ func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr return } - conn, err := h.options.Chain.Dial(raddr, + conn, err := h.options.Chain.Dial(ip, raddr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), @@ -868,7 +868,8 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { return nil } return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if au.Authenticate(conn.User(), string(password)) { + ip := GetSshIP(conn) + if au.IFAuthenticate(ip, conn.User(), string(password)) { return nil, nil } log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) diff --git a/tcp.go b/tcp.go index a255011..1e1ecd9 100644 --- a/tcp.go +++ b/tcp.go @@ -1,6 +1,8 @@ package gost -import "net" +import ( + "net" +) // tcpTransporter is a raw TCP transporter. type tcpTransporter struct{} @@ -23,7 +25,7 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er if opts.Chain == nil { return net.DialTimeout("tcp", addr, timeout) } - return opts.Chain.Dial(addr) + return opts.Chain.Dial(nil, addr) } func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { diff --git a/tls.go b/tls.go index 8526c6f..a0931f6 100644 --- a/tls.go +++ b/tls.go @@ -74,7 +74,7 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(addr) + conn, err = opts.Chain.Dial(nil, addr) } if err != nil { return diff --git a/tuntap.go b/tuntap.go index e38ab63..9edf7b1 100644 --- a/tuntap.go +++ b/tuntap.go @@ -162,13 +162,14 @@ func (h *tunHandler) Handle(conn net.Conn) { } var tempDelay time.Duration + ip := GetIP(conn) for { err := func() error { var err error var pc net.PacketConn // fake tcp mode will be ignored when the client specifies a chain. if raddr != nil && !h.options.Chain.IsEmpty() { - cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String()) + cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String()) if err != nil { return err } @@ -550,15 +551,15 @@ func (h *tapHandler) Handle(conn net.Conn) { return } } - var tempDelay time.Duration + ip := GetIP(conn) for { err := func() error { var err error var pc net.PacketConn // fake tcp mode will be ignored when the client specifies a chain. if raddr != nil && !h.options.Chain.IsEmpty() { - cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String()) + cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String()) if err != nil { return err } diff --git a/vsock.go b/vsock.go index 51aa6de..3d98cb8 100644 --- a/vsock.go +++ b/vsock.go @@ -27,10 +27,10 @@ func (tr *vsockTransporter) Dial(addr string, options ...DialOption) (net.Conn, } return vsock.Dial(vAddr.ContextID, vAddr.Port, nil) } - return opts.Chain.Dial(addr) + return opts.Chain.Dial(nil, addr) } -func parseUint32(s string) (uint32, error ) { +func parseUint32(s string) (uint32, error) { n, err := strconv.ParseUint(s, 10, 32) if err != nil { return 0, err diff --git a/ws.go b/ws.go index 9dc8f0d..51d6dda 100644 --- a/ws.go +++ b/ws.go @@ -104,7 +104,7 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(addr) + conn, err = opts.Chain.Dial(nil, addr) } if err != nil { return @@ -261,7 +261,7 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(addr) + conn, err = opts.Chain.Dial(nil, addr) } if err != nil { return