Browse Source

todo: parse ip

pull/1077/head
chuchur 3 months ago
parent
commit
08fcb8f36d
  1. 52
      chain.go
  2. 4
      dns.go
  3. 14
      forward.go
  4. 7
      http.go
  5. 6
      http2.go
  6. 60
      localAddr.go
  7. 5
      relay.go
  8. 17
      resolver.go
  9. 5
      sni.go
  10. 12
      socks.go
  11. 9
      ss.go
  12. 3
      ssh.go
  13. 8
      tuntap.go

52
chain.go

@ -9,13 +9,11 @@ import (
"time"
"github.com/go-log/log"
"golang.org/x/crypto/ssh"
)
var (
// ErrEmptyChain is an error that implies the chain is empty.
ErrEmptyChain = errors.New("empty chain")
localAddrKey = "localAddr"
)
// Chain is a proxy chain that holds a list of proxy node groups.
@ -111,29 +109,14 @@ func (c *Chain) IsEmpty() bool {
return c == nil || len(c.nodeGroups) == 0
}
func GetIP(conn net.Conn) (ip net.IP) {
IP := conn.LocalAddr().(*net.TCPAddr).IP
if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() {
return IP
}
return nil
}
func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) {
IP := conn.LocalAddr().(*net.TCPAddr).IP
if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() {
return IP
}
return nil
}
// Dial connects to the target TCP address addr through the chain.
// Deprecated: use DialContext instead.
func (c *Chain) Dial(ip net.IP, address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(ip, context.Background(), "tcp", address, opts...)
func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(context.Background(), "tcp", address, opts...)
}
// DialContext connects to the address on the named network using the provided context.
func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
opt(options)
@ -148,7 +131,7 @@ func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address str
}
for i := 0; i < retries; i++ {
conn, err = c.dialWithOptions(ip, ctx, network, address, options)
conn, err = c.dialWithOptions(ctx, network, address, options)
if err == nil {
break
}
@ -156,7 +139,7 @@ func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address str
return
}
func (c *Chain) dialWithOptions(ip net.IP, ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
if options == nil {
options = &ChainOptions{}
}
@ -215,24 +198,8 @@ func (c *Chain) dialWithOptions(ip net.IP, ctx context.Context, network, address
}
default:
}
var localAddr net.Addr
// ip := cn.LocalAddr().(*net.TCPAddr).IP
// ctx := context.WithValue(context.Background(), localAddrKey, ip)
// if v := ctx.Value(localAddrKey); v != nil {
// if ip, ok := v.(net.IP); ok {
// localAddr = &net.TCPAddr{
// IP: ip,
// Port: 0,
// }
// }
// }
if ip != nil {
localAddr = &net.TCPAddr{
IP: ip,
Port: 0,
}
}
localAddr := getLocalAddr(ctx)
d := &net.Dialer{
Timeout: timeout,
Control: controlFunction,
@ -407,6 +374,7 @@ type ChainOptions struct {
Hosts *Hosts
Resolver Resolver
Mark int
IP net.IP
}
// ChainOption allows a common way to set chain options.
@ -439,3 +407,9 @@ func ResolverChainOption(resolver Resolver) ChainOption {
opts.Resolver = resolver
}
}
func IPChainOption(ip net.IP) ChainOption {
return func(opts *ChainOptions) {
opts.IP = ip
}
}

4
dns.go

@ -82,8 +82,8 @@ func (h *dnsHandler) Handle(conn net.Conn) {
if resolver == nil {
resolver = defaultResolver
}
ip := GetIP(conn)
reply, err := resolver.Exchange(ip, context.Background(), b[:n])
ctx := getContext(conn, context.Background())
reply, err := resolver.Exchange(ctx, b[:n])
if err != nil {
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
return

14
forward.go

@ -119,7 +119,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
var cc net.Conn
var node Node
var err error
ip := GetIP(conn)
ip := getIP(conn)
for i := 0; i < retries; i++ {
if len(h.group.Nodes()) > 0 {
node, err = h.group.Next()
@ -129,10 +129,11 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
}
}
cc, err = h.options.Chain.Dial(ip, node.Addr,
cc, err = h.options.Chain.Dial(node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err != nil {
log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -198,8 +199,10 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
return
}
}
ip := GetIP(conn)
cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", node.Addr, ResolverChainOption(h.options.Resolver))
ip := getIP(conn)
cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr,
IPChainOption(ip),
ResolverChainOption(h.options.Resolver))
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -448,7 +451,8 @@ func (l *tcpRemoteForwardListener) Accept() (conn net.Conn, err error) {
func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
lastNode := l.chain.LastNode()
if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" {
return l.chain.Dial(nil, l.addr.String())
ip := getIP(conn)
return l.chain.Dial(l.addr.String(), IPChainOption(ip))
}
if l.isChainValid() {

7
http.go

@ -239,6 +239,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
var err error
var cc net.Conn
var route *Chain
ip := getIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -268,11 +269,11 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
continue
}
ip := GetIP(conn)
cc, err = route.Dial(ip, host,
cc, err = route.Dial(host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err == nil {
break
@ -358,7 +359,7 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
log.Logf("[http] %s -> %s : Authorization '%s' '%s'",
conn.RemoteAddr(), conn.LocalAddr(), u, p)
}
ip := GetIP(conn)
ip := getIP(conn)
if h.options.Authenticator == nil || h.options.Authenticator.IFAuthenticate(ip, u, p) {
return true
}

6
http2.go

@ -338,7 +338,7 @@ func (h *http2Handler) Handle(conn net.Conn) {
return
}
ip := GetIP(conn)
ip := getIP(conn)
h.roundTrip(ip, h2c.w, h2c.r)
}
@ -427,11 +427,11 @@ func (h *http2Handler) roundTrip(ip net.IP, w http.ResponseWriter, r *http.Reque
}
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(ip, host,
cc, err = route.Dial(host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err == nil {
break

60
localAddr.go

@ -0,0 +1,60 @@
package gost
import (
"context"
"net"
"golang.org/x/crypto/ssh"
)
var localAddrKey = "localAddr"
func getIP(conn net.Conn) (ip net.IP) {
IP := conn.LocalAddr().(*net.TCPAddr).IP
if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() {
return IP
}
return nil
}
func getContext1(ip net.IP) (ctx context.Context) {
ctx = context.Background()
if ip != nil {
return context.WithValue(ctx, localAddrKey, ip)
}
return
}
func getContext(conn net.Conn, parentCtx context.Context) (ctx context.Context) {
IP := getIP(conn)
if IP != nil {
return context.WithValue(parentCtx, localAddrKey, IP)
}
return parentCtx
}
func GetIP(ctx context.Context) (ip net.IP) {
if v := ctx.Value(localAddrKey); v != nil {
if ip, ok := v.(net.IP); ok && !ip.IsPrivate() && !ip.IsLoopback() {
return ip
}
}
return nil
}
func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) {
IP := conn.LocalAddr().(*net.TCPAddr).IP
if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() {
return IP
}
return nil
}
func getLocalAddr(ctx context.Context) (addr net.Addr) {
ip := GetIP(ctx)
if ip != nil {
addr = &net.TCPAddr{
IP: ip,
Port: 0,
}
}
return
}

5
relay.go

@ -165,7 +165,7 @@ func (h *relayHandler) Handle(conn net.Conn) {
Version: relay.Version1,
Status: relay.StatusOK,
}
ip := GetIP(conn)
ip := getIP(conn)
if h.options.Authenticator != nil && !h.options.Authenticator.IFAuthenticate(ip, user, pass) {
resp.Status = relay.StatusUnauthorized
resp.WriteTo(conn)
@ -230,10 +230,11 @@ func (h *relayHandler) Handle(conn net.Conn) {
log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr)
cc, err = h.options.Chain.DialContext(ip, ctx,
cc, err = h.options.Chain.DialContext(ctx,
network, raddr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
IPChainOption(ip),
)
if err != nil {
log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err)

17
resolver.go

@ -173,7 +173,7 @@ type Resolver interface {
Resolve(host string) ([]net.IP, error)
// Exchange performs a synchronous query,
// It sends the message query and waits for a reply.
Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error)
Exchange(ctx context.Context, query []byte) (reply []byte, err error)
}
// ReloadResolver is resolover that support live reloading.
@ -379,7 +379,7 @@ func (r *resolver) addSubnetOpt(m *dns.Msg) {
m.Extra = append(m.Extra, opt)
}
func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error) {
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
mq := &dns.Msg{}
if err = mq.Unpack(query); err != nil {
return
@ -408,10 +408,9 @@ func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply
}
r.addSubnetOpt(mq)
for _, ns := range r.copyServers() {
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String())
mr, err = r.exchangeMsg(ip, ctx, ns.exchanger, mq)
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
if err == nil {
break
}
@ -423,12 +422,12 @@ func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply
return mr.Pack()
}
func (r *resolver) exchangeMsg(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
query, err := mq.Pack()
if err != nil {
return
}
reply, err := ex.Exchange(ip, ctx, query)
reply, err := ex.Exchange(ctx, query)
if err != nil {
return
}
@ -657,7 +656,7 @@ func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.
// Exchanger is an interface for DNS synchronous query.
type Exchanger interface {
Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error)
Exchange(ctx context.Context, query []byte) ([]byte, error)
}
type exchangerOptions struct {
@ -704,9 +703,9 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) {
func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.options.chain.DialContext(ip, ctx,
c, err := ex.options.chain.DialContext(ctx,
"udp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)

5
sni.go

@ -134,7 +134,7 @@ func (h *sniHandler) Handle(conn net.Conn) {
var cc net.Conn
var route *Chain
ip := GetIP(conn)
ip := getIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -152,10 +152,11 @@ func (h *sniHandler) Handle(conn net.Conn) {
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(ip, host,
cc, err = route.Dial(host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err == nil {
break

12
socks.go

@ -168,7 +168,7 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (string,
if Debug {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
}
ip := GetIP(conn)
ip := getIP(conn)
if selector.Authenticator != nil && !selector.Authenticator.IFAuthenticate(ip, req.Username, req.Password) {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil {
@ -923,7 +923,7 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
var err error
var cc net.Conn
var route *Chain
ip := GetIP(conn)
ip := getIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -941,10 +941,11 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(ip, host,
cc, err = route.Dial(host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err == nil {
break
@ -1774,7 +1775,7 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) {
var err error
var cc net.Conn
var route *Chain
ip := GetIP(conn)
ip := getIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(addr)
if err != nil {
@ -1792,10 +1793,11 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) {
fmt.Fprintf(&buf, "%s", addr)
log.Log("[route]", buf.String())
cc, err = route.Dial(ip, addr,
cc, err = route.Dial(addr,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err == nil {
break

9
ss.go

@ -164,7 +164,7 @@ func (h *shadowHandler) Handle(conn net.Conn) {
var cc net.Conn
var route *Chain
ip := GetIP(conn)
ip := getIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -181,7 +181,8 @@ func (h *shadowHandler) Handle(conn net.Conn) {
}
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(ip, host,
cc, err = route.Dial(host,
IPChainOption(ip),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
@ -297,8 +298,8 @@ func (h *shadowUDPHandler) Handle(conn net.Conn) {
defer conn.Close()
var cc net.PacketConn
ip := GetIP(conn)
c, err := h.options.Chain.DialContext(ip, context.Background(), "udp", "")
ip := getIP(conn)
c, err := h.options.Chain.DialContext(context.Background(), "udp", "", IPChainOption(ip))
if err != nil {
log.Logf("[ssu] %s: %s", conn.LocalAddr(), err)
return

3
ssh.go

@ -611,11 +611,12 @@ func (h *sshForwardHandler) directPortForwardChannel(ip net.IP, channel ssh.Chan
return
}
conn, err := h.options.Chain.Dial(ip, raddr,
conn, err := h.options.Chain.Dial(raddr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
IPChainOption(ip),
)
if err != nil {
log.Logf("[ssh-tcp] %s - %s : %s", h.options.Node.Addr, raddr, err)

8
tuntap.go

@ -162,14 +162,14 @@ func (h *tunHandler) Handle(conn net.Conn) {
}
var tempDelay time.Duration
ip := GetIP(conn)
ip := getIP(conn)
for {
err := func() error {
var err error
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String())
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String(), IPChainOption(ip))
if err != nil {
return err
}
@ -552,14 +552,14 @@ func (h *tapHandler) Handle(conn net.Conn) {
}
}
var tempDelay time.Duration
ip := GetIP(conn)
ip := getIP(conn)
for {
err := func() error {
var err error
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String())
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String(), IPChainOption(ip))
if err != nil {
return err
}

Loading…
Cancel
Save