diff --git a/chain.go b/chain.go index d001980..8ad7a89 100644 --- a/chain.go +++ b/chain.go @@ -9,13 +9,11 @@ import ( "time" "github.com/go-log/log" - "golang.org/x/crypto/ssh" ) var ( // ErrEmptyChain is an error that implies the chain is empty. ErrEmptyChain = errors.New("empty chain") - localAddrKey = "localAddr" ) // Chain is a proxy chain that holds a list of proxy node groups. @@ -111,29 +109,14 @@ func (c *Chain) IsEmpty() bool { return c == nil || len(c.nodeGroups) == 0 } -func GetIP(conn net.Conn) (ip net.IP) { - IP := conn.LocalAddr().(*net.TCPAddr).IP - if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() { - return IP - } - return nil -} -func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) { - IP := conn.LocalAddr().(*net.TCPAddr).IP - if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() { - return IP - } - return nil -} - // Dial connects to the target TCP address addr through the chain. // Deprecated: use DialContext instead. -func (c *Chain) Dial(ip net.IP, address string, opts ...ChainOption) (conn net.Conn, err error) { - return c.DialContext(ip, context.Background(), "tcp", address, opts...) +func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) { + return c.DialContext(context.Background(), "tcp", address, opts...) } // DialContext connects to the address on the named network using the provided context. -func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) { +func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) { options := &ChainOptions{} for _, opt := range opts { opt(options) @@ -148,7 +131,7 @@ func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address str } for i := 0; i < retries; i++ { - conn, err = c.dialWithOptions(ip, ctx, network, address, options) + conn, err = c.dialWithOptions(ctx, network, address, options) if err == nil { break } @@ -156,7 +139,7 @@ func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address str return } -func (c *Chain) dialWithOptions(ip net.IP, ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { +func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { if options == nil { options = &ChainOptions{} } @@ -215,24 +198,8 @@ func (c *Chain) dialWithOptions(ip net.IP, ctx context.Context, network, address } default: } - var localAddr net.Addr - // ip := cn.LocalAddr().(*net.TCPAddr).IP - // ctx := context.WithValue(context.Background(), localAddrKey, ip) - // if v := ctx.Value(localAddrKey); v != nil { - // if ip, ok := v.(net.IP); ok { - // localAddr = &net.TCPAddr{ - // IP: ip, - // Port: 0, - // } - // } - // } - if ip != nil { - localAddr = &net.TCPAddr{ - IP: ip, - Port: 0, - } - } + localAddr := getLocalAddr(ctx) d := &net.Dialer{ Timeout: timeout, Control: controlFunction, @@ -407,6 +374,7 @@ type ChainOptions struct { Hosts *Hosts Resolver Resolver Mark int + IP net.IP } // ChainOption allows a common way to set chain options. @@ -439,3 +407,9 @@ func ResolverChainOption(resolver Resolver) ChainOption { opts.Resolver = resolver } } + +func IPChainOption(ip net.IP) ChainOption { + return func(opts *ChainOptions) { + opts.IP = ip + } +} diff --git a/dns.go b/dns.go index b87bdc4..efa8759 100644 --- a/dns.go +++ b/dns.go @@ -82,8 +82,8 @@ func (h *dnsHandler) Handle(conn net.Conn) { if resolver == nil { resolver = defaultResolver } - ip := GetIP(conn) - reply, err := resolver.Exchange(ip, context.Background(), b[:n]) + ctx := getContext(conn, context.Background()) + reply, err := resolver.Exchange(ctx, b[:n]) if err != nil { log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) return diff --git a/forward.go b/forward.go index e2913e8..e409ff6 100644 --- a/forward.go +++ b/forward.go @@ -119,7 +119,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { var cc net.Conn var node Node var err error - ip := GetIP(conn) + ip := getIP(conn) for i := 0; i < retries; i++ { if len(h.group.Nodes()) > 0 { node, err = h.group.Next() @@ -129,10 +129,11 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { } } - cc, err = h.options.Chain.Dial(ip, node.Addr, + cc, err = h.options.Chain.Dial(node.Addr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err != nil { log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -198,8 +199,10 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { return } } - ip := GetIP(conn) - cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", node.Addr, ResolverChainOption(h.options.Resolver)) + ip := getIP(conn) + cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr, + IPChainOption(ip), + ResolverChainOption(h.options.Resolver)) if err != nil { node.MarkDead() log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -448,7 +451,8 @@ func (l *tcpRemoteForwardListener) Accept() (conn net.Conn, err error) { func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { lastNode := l.chain.LastNode() if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { - return l.chain.Dial(nil, l.addr.String()) + ip := getIP(conn) + return l.chain.Dial(l.addr.String(), IPChainOption(ip)) } if l.isChainValid() { diff --git a/http.go b/http.go index ed5fafd..509a24a 100644 --- a/http.go +++ b/http.go @@ -239,6 +239,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { var err error var cc net.Conn var route *Chain + ip := getIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -268,11 +269,11 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) continue } - ip := GetIP(conn) - cc, err = route.Dial(ip, host, + cc, err = route.Dial(host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err == nil { break @@ -358,7 +359,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http. log.Logf("[http] %s -> %s : Authorization '%s' '%s'", conn.RemoteAddr(), conn.LocalAddr(), u, p) } - ip := GetIP(conn) + ip := getIP(conn) if h.options.Authenticator == nil || h.options.Authenticator.IFAuthenticate(ip, u, p) { return true } diff --git a/http2.go b/http2.go index 278644d..e79a73d 100644 --- a/http2.go +++ b/http2.go @@ -338,7 +338,7 @@ func (h *http2Handler) Handle(conn net.Conn) { return } - ip := GetIP(conn) + ip := getIP(conn) h.roundTrip(ip, h2c.w, h2c.r) } @@ -427,11 +427,11 @@ func (h *http2Handler) roundTrip(ip net.IP, w http.ResponseWriter, r *http.Reque } fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - - cc, err = route.Dial(ip, host, + cc, err = route.Dial(host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err == nil { break diff --git a/localAddr.go b/localAddr.go new file mode 100644 index 0000000..3602725 --- /dev/null +++ b/localAddr.go @@ -0,0 +1,60 @@ +package gost + +import ( + "context" + "net" + + "golang.org/x/crypto/ssh" +) + +var localAddrKey = "localAddr" + +func getIP(conn net.Conn) (ip net.IP) { + IP := conn.LocalAddr().(*net.TCPAddr).IP + if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() { + return IP + } + return nil +} + +func getContext1(ip net.IP) (ctx context.Context) { + ctx = context.Background() + if ip != nil { + return context.WithValue(ctx, localAddrKey, ip) + } + return +} +func getContext(conn net.Conn, parentCtx context.Context) (ctx context.Context) { + IP := getIP(conn) + if IP != nil { + return context.WithValue(parentCtx, localAddrKey, IP) + } + return parentCtx +} + +func GetIP(ctx context.Context) (ip net.IP) { + if v := ctx.Value(localAddrKey); v != nil { + if ip, ok := v.(net.IP); ok && !ip.IsPrivate() && !ip.IsLoopback() { + return ip + } + } + return nil +} +func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) { + IP := conn.LocalAddr().(*net.TCPAddr).IP + if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() { + return IP + } + return nil +} + +func getLocalAddr(ctx context.Context) (addr net.Addr) { + ip := GetIP(ctx) + if ip != nil { + addr = &net.TCPAddr{ + IP: ip, + Port: 0, + } + } + return +} diff --git a/relay.go b/relay.go index 0dd36c3..ba3569b 100644 --- a/relay.go +++ b/relay.go @@ -165,7 +165,7 @@ func (h *relayHandler) Handle(conn net.Conn) { Version: relay.Version1, Status: relay.StatusOK, } - ip := GetIP(conn) + ip := getIP(conn) if h.options.Authenticator != nil && !h.options.Authenticator.IFAuthenticate(ip, user, pass) { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) @@ -230,10 +230,11 @@ func (h *relayHandler) Handle(conn net.Conn) { log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr) - cc, err = h.options.Chain.DialContext(ip, ctx, + cc, err = h.options.Chain.DialContext(ctx, network, raddr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), + IPChainOption(ip), ) if err != nil { log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err) diff --git a/resolver.go b/resolver.go index 5a708f0..4cecc40 100644 --- a/resolver.go +++ b/resolver.go @@ -173,7 +173,7 @@ type Resolver interface { Resolve(host string) ([]net.IP, error) // Exchange performs a synchronous query, // It sends the message query and waits for a reply. - Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error) + Exchange(ctx context.Context, query []byte) (reply []byte, err error) } // ReloadResolver is resolover that support live reloading. @@ -379,7 +379,7 @@ func (r *resolver) addSubnetOpt(m *dns.Msg) { m.Extra = append(m.Extra, opt) } -func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error) { +func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { mq := &dns.Msg{} if err = mq.Unpack(query); err != nil { return @@ -408,10 +408,9 @@ func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply } r.addSubnetOpt(mq) - for _, ns := range r.copyServers() { log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String()) - mr, err = r.exchangeMsg(ip, ctx, ns.exchanger, mq) + mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) if err == nil { break } @@ -423,12 +422,12 @@ func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply return mr.Pack() } -func (r *resolver) exchangeMsg(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { +func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { query, err := mq.Pack() if err != nil { return } - reply, err := ex.Exchange(ip, ctx, query) + reply, err := ex.Exchange(ctx, query) if err != nil { return } @@ -657,7 +656,7 @@ func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time. // Exchanger is an interface for DNS synchronous query. type Exchanger interface { - Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) + Exchange(ctx context.Context, query []byte) ([]byte, error) } type exchangerOptions struct { @@ -704,9 +703,9 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { } } -func (ex *dnsExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) { +func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.options.chain.DialContext(ip, ctx, + c, err := ex.options.chain.DialContext(ctx, "udp", ex.addr, TimeoutChainOption(ex.options.timeout), ) diff --git a/sni.go b/sni.go index 2b532c0..38500ff 100644 --- a/sni.go +++ b/sni.go @@ -134,7 +134,7 @@ func (h *sniHandler) Handle(conn net.Conn) { var cc net.Conn var route *Chain - ip := GetIP(conn) + ip := getIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -152,10 +152,11 @@ func (h *sniHandler) Handle(conn net.Conn) { fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(ip, host, + cc, err = route.Dial(host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err == nil { break diff --git a/socks.go b/socks.go index a7a1583..08db5b6 100644 --- a/socks.go +++ b/socks.go @@ -168,7 +168,7 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (string, if Debug { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String()) } - ip := GetIP(conn) + ip := getIP(conn) if selector.Authenticator != nil && !selector.Authenticator.IFAuthenticate(ip, req.Username, req.Password) { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) if err := resp.Write(conn); err != nil { @@ -923,7 +923,7 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { var err error var cc net.Conn var route *Chain - ip := GetIP(conn) + ip := getIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -941,10 +941,11 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(ip, host, + cc, err = route.Dial(host, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err == nil { break @@ -1774,7 +1775,7 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { var err error var cc net.Conn var route *Chain - ip := GetIP(conn) + ip := getIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(addr) if err != nil { @@ -1792,10 +1793,11 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { fmt.Fprintf(&buf, "%s", addr) log.Log("[route]", buf.String()) - cc, err = route.Dial(ip, addr, + cc, err = route.Dial(addr, TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err == nil { break diff --git a/ss.go b/ss.go index 4dbf7c4..633e7f9 100644 --- a/ss.go +++ b/ss.go @@ -164,7 +164,7 @@ func (h *shadowHandler) Handle(conn net.Conn) { var cc net.Conn var route *Chain - ip := GetIP(conn) + ip := getIP(conn) for i := 0; i < retries; i++ { route, err = h.options.Chain.selectRouteFor(host) if err != nil { @@ -181,7 +181,8 @@ func (h *shadowHandler) Handle(conn net.Conn) { } fmt.Fprintf(&buf, "%s", host) log.Log("[route]", buf.String()) - cc, err = route.Dial(ip, host, + cc, err = route.Dial(host, + IPChainOption(ip), TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), @@ -297,8 +298,8 @@ func (h *shadowUDPHandler) Handle(conn net.Conn) { defer conn.Close() var cc net.PacketConn - ip := GetIP(conn) - c, err := h.options.Chain.DialContext(ip, context.Background(), "udp", "") + ip := getIP(conn) + c, err := h.options.Chain.DialContext(context.Background(), "udp", "", IPChainOption(ip)) if err != nil { log.Logf("[ssu] %s: %s", conn.LocalAddr(), err) return diff --git a/ssh.go b/ssh.go index 8d17bb2..007e71d 100644 --- a/ssh.go +++ b/ssh.go @@ -611,11 +611,12 @@ func (h *sshForwardHandler) directPortForwardChannel(ip net.IP, channel ssh.Chan return } - conn, err := h.options.Chain.Dial(ip, raddr, + conn, err := h.options.Chain.Dial(raddr, RetryChainOption(h.options.Retries), TimeoutChainOption(h.options.Timeout), HostsChainOption(h.options.Hosts), ResolverChainOption(h.options.Resolver), + IPChainOption(ip), ) if err != nil { log.Logf("[ssh-tcp] %s - %s : %s", h.options.Node.Addr, raddr, err) diff --git a/tuntap.go b/tuntap.go index 9edf7b1..cf46685 100644 --- a/tuntap.go +++ b/tuntap.go @@ -162,14 +162,14 @@ func (h *tunHandler) Handle(conn net.Conn) { } var tempDelay time.Duration - ip := GetIP(conn) + ip := getIP(conn) for { err := func() error { var err error var pc net.PacketConn // fake tcp mode will be ignored when the client specifies a chain. if raddr != nil && !h.options.Chain.IsEmpty() { - cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String()) + cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String(), IPChainOption(ip)) if err != nil { return err } @@ -552,14 +552,14 @@ func (h *tapHandler) Handle(conn net.Conn) { } } var tempDelay time.Duration - ip := GetIP(conn) + ip := getIP(conn) for { err := func() error { var err error var pc net.PacketConn // fake tcp mode will be ignored when the client specifies a chain. if raddr != nil && !h.options.Chain.IsEmpty() { - cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String()) + cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String(), IPChainOption(ip)) if err != nil { return err }