Browse Source

处理A进A出

pull/1077/head
chuchur 4 months ago
parent
commit
61e4cf69a1
  1. 9
      auth.go
  2. 51
      chain.go
  3. 6
      cmd/gost/build.sh
  4. 3
      dns.go
  5. 4
      examples/ssu/ssu.go
  6. 14
      forward.go
  7. 7
      http.go
  8. 19
      http2.go
  9. 8
      redirect.go
  10. 6
      relay.go
  11. 58
      resolver.go
  12. 3
      sni.go
  13. 12
      socks.go
  14. 7
      ss.go
  15. 15
      ssh.go
  16. 6
      tcp.go
  17. 2
      tls.go
  18. 7
      tuntap.go
  19. 4
      vsock.go
  20. 4
      ws.go

9
auth.go

@ -2,7 +2,6 @@ package gost
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"io"
@ -17,13 +16,11 @@ import (
// Authenticator is an interface for user authentication.
type Authenticator interface {
Authenticate(user, password string) bool
InflowwAuthenticateContext(ctx context.Context, user, password string) bool
IFAuthenticate(ip net.IP, user, password string) bool
}
func (au *LocalAuthenticator) InflowwAuthenticateContext(ctx context.Context, user, password string) bool {
inboundIP := ctx.Value("InboundIP")
if inboundIP != nil {
ip := inboundIP.(net.IP)
func (au *LocalAuthenticator) IFAuthenticate(ip net.IP, user, password string) bool {
if ip != nil {
if !ip.IsLoopback() && !ip.IsPrivate() {
expected := GeneratePass(ip.String(), user)
if expected == password {

51
chain.go

@ -9,11 +9,13 @@ import (
"time"
"github.com/go-log/log"
"golang.org/x/crypto/ssh"
)
var (
// ErrEmptyChain is an error that implies the chain is empty.
ErrEmptyChain = errors.New("empty chain")
localAddrKey = "localAddr"
)
// Chain is a proxy chain that holds a list of proxy node groups.
@ -109,14 +111,29 @@ func (c *Chain) IsEmpty() bool {
return c == nil || len(c.nodeGroups) == 0
}
func GetIP(conn net.Conn) (ip net.IP) {
IP := conn.LocalAddr().(*net.TCPAddr).IP
if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() {
return IP
}
return nil
}
func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) {
IP := conn.LocalAddr().(*net.TCPAddr).IP
if IP != nil && !IP.IsPrivate() && !IP.IsLoopback() {
return IP
}
return nil
}
// Dial connects to the target TCP address addr through the chain.
// Deprecated: use DialContext instead.
func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(context.Background(), "tcp", address, opts...)
func (c *Chain) Dial(ip net.IP, address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(ip, context.Background(), "tcp", address, opts...)
}
// DialContext connects to the address on the named network using the provided context.
func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
func (c *Chain) DialContext(ip net.IP, ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
opt(options)
@ -131,7 +148,7 @@ func (c *Chain) DialContext(ctx context.Context, network, address string, opts .
}
for i := 0; i < retries; i++ {
conn, err = c.dialWithOptions(ctx, network, address, options)
conn, err = c.dialWithOptions(ip, ctx, network, address, options)
if err == nil {
break
}
@ -139,7 +156,7 @@ func (c *Chain) DialContext(ctx context.Context, network, address string, opts .
return
}
func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
func (c *Chain) dialWithOptions(ip net.IP, ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
if options == nil {
options = &ChainOptions{}
}
@ -198,10 +215,28 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op
}
default:
}
var localAddr net.Addr
// ip := cn.LocalAddr().(*net.TCPAddr).IP
// ctx := context.WithValue(context.Background(), localAddrKey, ip)
// if v := ctx.Value(localAddrKey); v != nil {
// if ip, ok := v.(net.IP); ok {
// localAddr = &net.TCPAddr{
// IP: ip,
// Port: 0,
// }
// }
// }
if ip != nil {
localAddr = &net.TCPAddr{
IP: ip,
Port: 0,
}
}
d := &net.Dialer{
Timeout: timeout,
Control: controlFunction,
// LocalAddr: laddr, // TODO: optional local address
Timeout: timeout,
Control: controlFunction,
LocalAddr: localAddr,
}
return d.DialContext(ctx, network, ipAddr)
}

6
cmd/gost/build.sh

@ -2,6 +2,6 @@ GOOS=linux GOARCH=amd64 go build -o gost
rsync -avz gost [email protected]:/root/gost/gost
rsync -avz gost [email protected]:/root/gost/gost
rsync -avz gost [email protected]:/root/gost/gost
# rsync -avz gost [email protected]:/root/gost/gost
# rsync -avz gost [email protected]:/root/gost/gost
# rsync -avz gost [email protected]:/root/gost/gost

3
dns.go

@ -82,7 +82,8 @@ func (h *dnsHandler) Handle(conn net.Conn) {
if resolver == nil {
resolver = defaultResolver
}
reply, err := resolver.Exchange(context.Background(), b[:n])
ip := GetIP(conn)
reply, err := resolver.Exchange(ip, context.Background(), b[:n])
if err != nil {
log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
return

4
examples/ssu/ssu.go

@ -28,13 +28,13 @@ func ssuClient() {
if err != nil {
log.Fatal(err)
}
cc := ss.NewSecurePacketConn(conn, cp, false)
cc := ss.NewSecurePacketConn(conn, cp)
raddr, _ := net.ResolveUDPAddr("udp", ":8080")
msg := []byte(`abcdefghijklmnopqrstuvwxyz`)
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), msg)
buf := bytes.Buffer{}
dgram.Write(&buf)
dgram.WriteTo(&buf)
for {
log.Printf("%# x", buf.Bytes()[3:])
if _, err := cc.WriteTo(buf.Bytes()[3:], addr); err != nil {

14
forward.go

@ -119,6 +119,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
var cc net.Conn
var node Node
var err error
ip := GetIP(conn)
for i := 0; i < retries; i++ {
if len(h.group.Nodes()) > 0 {
node, err = h.group.Next()
@ -128,7 +129,7 @@ func (h *tcpDirectForwardHandler) Handle(conn net.Conn) {
}
}
cc, err = h.options.Chain.Dial(node.Addr,
cc, err = h.options.Chain.Dial(ip, node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
ResolverChainOption(h.options.Resolver),
@ -197,13 +198,8 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
return
}
}
cc, err := h.options.Chain.DialContext(
context.Background(),
"udp",
node.Addr,
ResolverChainOption(h.options.Resolver),
)
ip := GetIP(conn)
cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", node.Addr, ResolverChainOption(h.options.Resolver))
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -452,7 +448,7 @@ func (l *tcpRemoteForwardListener) Accept() (conn net.Conn, err error) {
func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
lastNode := l.chain.LastNode()
if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" {
return l.chain.Dial(l.addr.String())
return l.chain.Dial(nil, l.addr.String())
}
if l.isChainValid() {

7
http.go

@ -268,8 +268,8 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
continue
}
cc, err = route.Dial(host,
ip := GetIP(conn)
cc, err = route.Dial(ip, host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
@ -358,7 +358,8 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
log.Logf("[http] %s -> %s : Authorization '%s' '%s'",
conn.RemoteAddr(), conn.LocalAddr(), u, p)
}
if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
ip := GetIP(conn)
if h.options.Authenticator == nil || h.options.Authenticator.IFAuthenticate(ip, u, p) {
return true
}

19
http2.go

@ -142,7 +142,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn,
if !ok {
// NOTE: There is no real connection to the HTTP2 server at this moment.
// So we try to connect to the server to check the server health.
conn, err := opts.Chain.Dial(addr)
conn, err := opts.Chain.Dial(nil, addr)
if err != nil {
log.Log("http2 dial:", addr, err)
return nil, err
@ -156,7 +156,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn,
transport := http2.Transport{
TLSClientConfig: tr.tlsConfig,
DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) {
conn, err := opts.Chain.Dial(adr)
conn, err := opts.Chain.Dial(nil, adr)
if err != nil {
return nil, err
}
@ -234,7 +234,7 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err
transport := http2.Transport{
TLSClientConfig: tr.tlsConfig,
DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) {
conn, err := opts.Chain.Dial(addr)
conn, err := opts.Chain.Dial(nil, addr)
if err != nil {
return nil, err
}
@ -338,10 +338,11 @@ func (h *http2Handler) Handle(conn net.Conn) {
return
}
h.roundTrip(h2c.w, h2c.r)
ip := GetIP(conn)
h.roundTrip(ip, h2c.w, h2c.r)
}
func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
func (h *http2Handler) roundTrip(ip net.IP, w http.ResponseWriter, r *http.Request) {
host := r.Header.Get("Gost-Target")
if host == "" {
host = r.Host
@ -391,7 +392,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
Body: io.NopCloser(bytes.NewReader([]byte{})),
}
if !h.authenticate(w, r, resp) {
if !h.authenticate(ip, w, r, resp) {
return
}
@ -427,7 +428,7 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(host,
cc, err = route.Dial(ip, host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
@ -481,13 +482,13 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) {
log.Logf("[http2] %s >-< %s", r.RemoteAddr, host)
}
func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) {
func (h *http2Handler) authenticate(ip net.IP, w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) {
laddr := h.options.Addr
u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization"))
if Debug && (u != "" || p != "") {
log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p)
}
if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
if h.options.Authenticator == nil || h.options.Authenticator.IFAuthenticate(ip, u, p) {
return true
}

8
redirect.go

@ -53,8 +53,8 @@ func (h *tcpRedirectHandler) Handle(c net.Conn) {
defer conn.Close()
log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr)
cc, err := h.options.Chain.DialContext(context.Background(),
ip := GetIP(c)
cc, err := h.options.Chain.DialContext(ip, context.Background(),
"tcp", dstAddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
@ -134,8 +134,8 @@ func (h *udpRedirectHandler) Handle(conn net.Conn) {
log.Log("[red-udp] wrong connection type")
return
}
cc, err := h.options.Chain.DialContext(context.Background(),
ip := GetIP(conn)
cc, err := h.options.Chain.DialContext(ip, context.Background(),
"udp", raddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),

6
relay.go

@ -165,7 +165,8 @@ func (h *relayHandler) Handle(conn net.Conn) {
Version: relay.Version1,
Status: relay.StatusOK,
}
if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) {
ip := GetIP(conn)
if h.options.Authenticator != nil && !h.options.Authenticator.IFAuthenticate(ip, user, pass) {
resp.Status = relay.StatusUnauthorized
resp.WriteTo(conn)
log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user)
@ -228,7 +229,8 @@ func (h *relayHandler) Handle(conn net.Conn) {
}
log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr)
cc, err = h.options.Chain.DialContext(ctx,
cc, err = h.options.Chain.DialContext(ip, ctx,
network, raddr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),

58
resolver.go

@ -173,7 +173,7 @@ type Resolver interface {
Resolve(host string) ([]net.IP, error)
// Exchange performs a synchronous query,
// It sends the message query and waits for a reply.
Exchange(ctx context.Context, query []byte) (reply []byte, err error)
Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error)
}
// ReloadResolver is resolover that support live reloading.
@ -282,7 +282,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
ctx := context.Background()
for _, ns := range r.copyServers() {
ips, err = r.resolve(ctx, ns.exchanger, host)
ips, err = r.resolve(ip, ctx, ns.exchanger, host)
if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
continue
@ -299,7 +299,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return
}
func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
func (r *resolver) resolve(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
if ex == nil {
return
}
@ -309,36 +309,36 @@ func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips
r.mux.RUnlock()
if prefer == "ipv6" { // prefer ipv6
if ips, err = r.resolve6(ctx, ex, host); len(ips) > 0 {
if ips, err = r.resolve6(ip, ctx, ex, host); len(ips) > 0 {
return
}
return r.resolve4(ctx, ex, host)
return r.resolve4(ip, ctx, ex, host)
}
if ips, err = r.resolve4(ctx, ex, host); len(ips) > 0 {
if ips, err = r.resolve4(ip, ctx, ex, host); len(ips) > 0 {
return
}
return r.resolve6(ctx, ex, host)
return r.resolve6(ip, ctx, ex, host)
}
func (r *resolver) resolve4(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
func (r *resolver) resolve4(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
mq := dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
return r.resolveIPs(ctx, ex, &mq)
return r.resolveIPs(ip, ctx, ex, &mq)
}
func (r *resolver) resolve6(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
func (r *resolver) resolve6(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
mq := dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
return r.resolveIPs(ctx, ex, &mq)
return r.resolveIPs(ip, ctx, ex, &mq)
}
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
func (r *resolver) resolveIPs(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
key := newResolverCacheKey(&mq.Question[0])
mr := r.cache.loadCache(key)
if mr == nil {
r.addSubnetOpt(mq)
mr, err = r.exchangeMsg(ctx, ex, mq)
mr, err = r.exchangeMsg(ip, ctx, ex, mq)
if err != nil {
return
}
@ -379,7 +379,7 @@ func (r *resolver) addSubnetOpt(m *dns.Msg) {
m.Extra = append(m.Extra, opt)
}
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
func (r *resolver) Exchange(ip net.IP, ctx context.Context, query []byte) (reply []byte, err error) {
mq := &dns.Msg{}
if err = mq.Unpack(query); err != nil {
return
@ -411,7 +411,7 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er
for _, ns := range r.copyServers() {
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String())
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
mr, err = r.exchangeMsg(ip, ctx, ns.exchanger, mq)
if err == nil {
break
}
@ -423,12 +423,12 @@ func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, er
return mr.Pack()
}
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
func (r *resolver) exchangeMsg(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
query, err := mq.Pack()
if err != nil {
return
}
reply, err := ex.Exchange(ctx, query)
reply, err := ex.Exchange(ip, ctx, query)
if err != nil {
return
}
@ -657,7 +657,7 @@ func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.
// Exchanger is an interface for DNS synchronous query.
type Exchanger interface {
Exchange(ctx context.Context, query []byte) ([]byte, error)
Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error)
}
type exchangerOptions struct {
@ -704,9 +704,9 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
func (ex *dnsExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.options.chain.DialContext(ctx,
c, err := ex.options.chain.DialContext(ip, ctx,
"udp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)
@ -754,9 +754,9 @@ func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
func (ex *dnsTCPExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.options.chain.DialContext(ctx,
c, err := ex.options.chain.DialContext(ip, ctx,
"tcp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)
@ -782,6 +782,7 @@ func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte,
}
type dotExchanger struct {
ip net.IP
addr string
tlsConfig *tls.Config
options exchangerOptions
@ -810,8 +811,8 @@ func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption
}
}
func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
conn, err = ex.options.chain.DialContext(ctx,
func (ex *dotExchanger) dial(ip net.IP, ctx context.Context, network, address string) (conn net.Conn, err error) {
conn, err = ex.options.chain.DialContext(ip, ctx,
network, address,
TimeoutChainOption(ex.options.timeout),
)
@ -823,9 +824,9 @@ func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn
return
}
func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
func (ex *dotExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.dial(ctx, "tcp", ex.addr)
c, err := ex.dial(ip, ctx, "tcp", ex.addr)
if err != nil {
return nil, err
}
@ -881,8 +882,9 @@ func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOp
return ex
}
func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return ex.options.chain.DialContext(ctx,
func (ex *dohExchanger) dialContext(ip net.IP, ctx context.Context, network, address string) (net.Conn, error) {
// todo::
return ex.options.chain.DialContext(ip, ctx,
network, address,
TimeoutChainOption(ex.options.timeout),
)

3
sni.go

@ -134,6 +134,7 @@ func (h *sniHandler) Handle(conn net.Conn) {
var cc net.Conn
var route *Chain
ip := GetIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -151,7 +152,7 @@ func (h *sniHandler) Handle(conn net.Conn) {
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(host,
cc, err = route.Dial(ip, host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),

12
socks.go

@ -168,8 +168,8 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (string,
if Debug {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
}
if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) {
ip := GetIP(conn)
if selector.Authenticator != nil && !selector.Authenticator.IFAuthenticate(ip, req.Username, req.Password) {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -431,7 +431,7 @@ func (tr *socks5MuxBindTransporter) Dial(addr string, options ...DialOption) (co
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
conn, err = opts.Chain.Dial(nil, addr)
}
if err != nil {
return
@ -923,6 +923,7 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
var err error
var cc net.Conn
var route *Chain
ip := GetIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -940,7 +941,7 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(host,
cc, err = route.Dial(ip, host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
@ -1773,6 +1774,7 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) {
var err error
var cc net.Conn
var route *Chain
ip := GetIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(addr)
if err != nil {
@ -1790,7 +1792,7 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) {
fmt.Fprintf(&buf, "%s", addr)
log.Log("[route]", buf.String())
cc, err = route.Dial(addr,
cc, err = route.Dial(ip, addr,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),

7
ss.go

@ -164,6 +164,7 @@ func (h *shadowHandler) Handle(conn net.Conn) {
var cc net.Conn
var route *Chain
ip := GetIP(conn)
for i := 0; i < retries; i++ {
route, err = h.options.Chain.selectRouteFor(host)
if err != nil {
@ -180,8 +181,7 @@ func (h *shadowHandler) Handle(conn net.Conn) {
}
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(host,
cc, err = route.Dial(ip, host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
@ -297,7 +297,8 @@ func (h *shadowUDPHandler) Handle(conn net.Conn) {
defer conn.Close()
var cc net.PacketConn
c, err := h.options.Chain.DialContext(context.Background(), "udp", "")
ip := GetIP(conn)
c, err := h.options.Chain.DialContext(ip, context.Background(), "udp", "")
if err != nil {
log.Logf("[ssu] %s: %s", conn.LocalAddr(), err)
return

15
ssh.go

@ -194,7 +194,7 @@ func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
conn, err = opts.Chain.Dial(nil, addr)
}
if err != nil {
return
@ -308,7 +308,7 @@ func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn n
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
conn, err = opts.Chain.Dial(nil, addr)
}
if err != nil {
return
@ -583,9 +583,9 @@ func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewCha
if p.Host1 == "<nil>" {
p.Host1 = ""
}
ip := GetSshIP(conn)
go ssh.DiscardRequests(requests)
go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1))
go h.directPortForwardChannel(ip, channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1))
default:
log.Log("[ssh] Unknown channel type:", t)
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
@ -596,7 +596,7 @@ func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewCha
conn.Wait()
}
func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) {
func (h *sshForwardHandler) directPortForwardChannel(ip net.IP, channel ssh.Channel, raddr string) {
defer channel.Close()
log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr)
@ -611,7 +611,7 @@ func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr
return
}
conn, err := h.options.Chain.Dial(raddr,
conn, err := h.options.Chain.Dial(ip, raddr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
@ -868,7 +868,8 @@ func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc {
return nil
}
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if au.Authenticate(conn.User(), string(password)) {
ip := GetSshIP(conn)
if au.IFAuthenticate(ip, conn.User(), string(password)) {
return nil, nil
}
log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User())

6
tcp.go

@ -1,6 +1,8 @@
package gost
import "net"
import (
"net"
)
// tcpTransporter is a raw TCP transporter.
type tcpTransporter struct{}
@ -23,7 +25,7 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er
if opts.Chain == nil {
return net.DialTimeout("tcp", addr, timeout)
}
return opts.Chain.Dial(addr)
return opts.Chain.Dial(nil, addr)
}
func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {

2
tls.go

@ -74,7 +74,7 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
conn, err = opts.Chain.Dial(nil, addr)
}
if err != nil {
return

7
tuntap.go

@ -162,13 +162,14 @@ func (h *tunHandler) Handle(conn net.Conn) {
}
var tempDelay time.Duration
ip := GetIP(conn)
for {
err := func() error {
var err error
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String())
cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String())
if err != nil {
return err
}
@ -550,15 +551,15 @@ func (h *tapHandler) Handle(conn net.Conn) {
return
}
}
var tempDelay time.Duration
ip := GetIP(conn)
for {
err := func() error {
var err error
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String())
cc, err := h.options.Chain.DialContext(ip, context.Background(), "udp", raddr.String())
if err != nil {
return err
}

4
vsock.go

@ -27,10 +27,10 @@ func (tr *vsockTransporter) Dial(addr string, options ...DialOption) (net.Conn,
}
return vsock.Dial(vAddr.ContextID, vAddr.Port, nil)
}
return opts.Chain.Dial(addr)
return opts.Chain.Dial(nil, addr)
}
func parseUint32(s string) (uint32, error ) {
func parseUint32(s string) (uint32, error) {
n, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return 0, err

4
ws.go

@ -104,7 +104,7 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
conn, err = opts.Chain.Dial(nil, addr)
}
if err != nil {
return
@ -261,7 +261,7 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co
if opts.Chain == nil {
conn, err = net.DialTimeout("tcp", addr, timeout)
} else {
conn, err = opts.Chain.Dial(addr)
conn, err = opts.Chain.Dial(nil, addr)
}
if err != nil {
return

Loading…
Cancel
Save