Browse Source

1.支持同源IP进出

2.使用Infloww认证
pull/1077/head
karl 3 years ago
parent
commit
abecfac548
  1. 27
      auth.go
  2. 12
      chain.go
  3. 8
      client.go
  4. 40
      cmd/gost/route.go
  5. 18
      http.go
  6. 348
      quic.go
  7. 463
      quic_test.go
  8. 14
      socks.go

27
auth.go

@ -2,15 +2,42 @@ package gost
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"io"
"net"
"strings"
"sync"
"time"
"github.com/go-log/log"
)
// Authenticator is an interface for user authentication.
type Authenticator interface {
Authenticate(user, password string) bool
InflowwAuthenticateContext(ctx context.Context, user, password string) bool
}
func (au *LocalAuthenticator) InflowwAuthenticateContext(ctx context.Context, user, password string) bool {
inboundIP := ctx.Value("InboundIP")
if inboundIP != nil {
ip := inboundIP.(net.IP)
if !ip.IsLoopback() && !ip.IsPrivate() {
p := ip.String()
src := p + user + "&&4sg123g[]/~"
hash := sha256.New()
hash.Write([]byte(src))
hashedSrc := hash.Sum(nil)
hashedSrcHex := hex.EncodeToString(hashedSrc)
if hashedSrcHex == password {
return true
} else {
log.Logf("user pass %s/%s, expect pass %s", user, password, hashedSrcHex)
}
}
}
return false
}
// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs.

12
chain.go

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"strings"
"syscall"
"time"
@ -201,7 +202,16 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op
d := &net.Dialer{
Timeout: timeout,
Control: controlFunction,
// LocalAddr: laddr, // TODO: optional local address
}
// use same ip between inbound and outbound
inboundIP := ctx.Value("InboundIP")
if inboundIP != nil && strings.ToLower(network) == "tcp" {
ip := inboundIP.(net.IP)
if !ip.IsLoopback() && !ip.IsPrivate() {
d.LocalAddr = &net.TCPAddr{
IP: ip,
}
}
}
return d.DialContext(ctx, network, ipAddr)
}

8
client.go

@ -121,7 +121,6 @@ type HandshakeOptions struct {
TLSConfig *tls.Config
WSOptions *WSOptions
KCPConfig *KCPConfig
QUICConfig *QUICConfig
SSHConfig *SSHConfig
}
@ -191,13 +190,6 @@ func KCPConfigHandshakeOption(config *KCPConfig) HandshakeOption {
}
}
// QUICConfigHandshakeOption specifies the QUIC config used by QUIC handshake
func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption {
return func(opts *HandshakeOptions) {
opts.QUICConfig = config
}
}
// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake.
func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption {
return func(opts *HandshakeOptions) {

40
cmd/gost/route.go

@ -1,7 +1,6 @@
package main
import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
@ -205,26 +204,6 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
} else {
tr = gost.SSHTunnelTransporter()
}
case "quic":
config := &gost.QUICConfig{
TLSConfig: tlsCfg,
KeepAlive: node.GetBool("keepalive"),
Timeout: timeout,
IdleTimeout: node.GetDuration("idle"),
}
if config.KeepAlive {
config.KeepAlivePeriod = node.GetDuration("ttl")
if config.KeepAlivePeriod == 0 {
config.KeepAlivePeriod = 10 * time.Second
}
}
if cipher := node.Get("cipher"); cipher != "" {
sum := sha256.Sum256([]byte(cipher))
config.Key = sum[:]
}
tr = gost.QUICTransporter(config)
case "http2":
tr = gost.HTTP2Transporter(tlsCfg)
case "h2":
@ -457,25 +436,6 @@ func (r *route) GenRouters() ([]router, error) {
} else {
ln, err = gost.SSHTunnelListener(node.Addr, config)
}
case "quic":
config := &gost.QUICConfig{
TLSConfig: tlsCfg,
KeepAlive: node.GetBool("keepalive"),
Timeout: timeout,
IdleTimeout: node.GetDuration("idle"),
}
if config.KeepAlive {
config.KeepAlivePeriod = node.GetDuration("ttl")
if config.KeepAlivePeriod == 0 {
config.KeepAlivePeriod = 10 * time.Second
}
}
if cipher := node.Get("cipher"); cipher != "" {
sum := sha256.Sum256([]byte(cipher))
config.Key = sum[:]
}
ln, err = gost.QUICListener(node.Addr, config)
case "http2":
ln, err = gost.HTTP2Listener(node.Addr, tlsCfg)
case "h2":

18
http.go

@ -142,6 +142,12 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
return
}
ctx := req.Context()
inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if ok {
ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP)
}
// try to get the actual host.
if v := req.Header.Get("Gost-Target"); v != "" {
if h, err := decodeServerName(v); err == nil {
@ -208,7 +214,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
return
}
if !h.authenticate(conn, req, resp) {
if !h.authenticateContext(ctx, conn, req, resp) {
return
}
@ -268,7 +274,7 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
continue
}
cc, err = route.Dial(host,
cc, err = route.DialContext(ctx, "tcp", host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),
@ -313,13 +319,13 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) {
log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host)
}
func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) {
func (h *httpHandler) authenticateContext(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response) (ok bool) {
u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization"))
if Debug && (u != "" || p != "") {
log.Logf("[http] %s -> %s : Authorization '%s' '%s'",
conn.RemoteAddr(), conn.LocalAddr(), u, p)
}
if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) {
if h.options.Authenticator == nil || h.options.Authenticator.InflowwAuthenticateContext(ctx, u, p) {
return true
}
@ -340,7 +346,9 @@ func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.
resp = r
}
case "host":
cc, err := net.Dial("tcp", ss[1])
d := net.Dialer{
}
cc, err := d.DialContext(ctx, "tcp", ss[1])
if err == nil {
defer cc.Close()

348
quic.go

@ -1,348 +0,0 @@
package gost
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/tls"
"errors"
"io"
"net"
"sync"
"time"
"github.com/go-log/log"
quic "github.com/lucas-clemente/quic-go"
)
type quicSession struct {
session quic.EarlyConnection
}
func (session *quicSession) GetConn() (*quicConn, error) {
stream, err := session.session.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
return &quicConn{
Stream: stream,
laddr: session.session.LocalAddr(),
raddr: session.session.RemoteAddr(),
}, nil
}
func (session *quicSession) Close() error {
return session.session.CloseWithError(quic.ApplicationErrorCode(0), "closed")
}
type quicTransporter struct {
config *QUICConfig
sessionMutex sync.Mutex
sessions map[string]*quicSession
}
// QUICTransporter creates a Transporter that is used by QUIC proxy client.
func QUICTransporter(config *QUICConfig) Transporter {
if config == nil {
config = &QUICConfig{}
}
return &quicTransporter{
config: config,
sessions: make(map[string]*quicSession),
}
}
func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) {
opts := &DialOptions{}
for _, option := range options {
option(opts)
}
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
session, ok := tr.sessions[addr]
if !ok {
var pc net.PacketConn
pc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return
}
if tr.config != nil && tr.config.Key != nil {
pc = &quicCipherConn{PacketConn: pc, key: tr.config.Key}
}
session, err = tr.initSession(udpAddr, pc)
if err != nil {
pc.Close()
return nil, err
}
tr.sessions[addr] = session
}
conn, err = session.GetConn()
if err != nil {
session.Close()
delete(tr.sessions, addr)
return nil, err
}
return conn, nil
}
func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return conn, nil
}
func (tr *quicTransporter) initSession(addr net.Addr, conn net.PacketConn) (*quicSession, error) {
config := tr.config
if config == nil {
config = &QUICConfig{}
}
if config.TLSConfig == nil {
config.TLSConfig = &tls.Config{InsecureSkipVerify: true}
}
quicConfig := &quic.Config{
HandshakeIdleTimeout: config.Timeout,
MaxIdleTimeout: config.IdleTimeout,
KeepAlivePeriod: config.KeepAlivePeriod,
Versions: []quic.VersionNumber{
quic.Version1,
quic.VersionDraft29,
},
}
session, err := quic.DialEarly(conn, addr, addr.String(), tlsConfigQUICALPN(config.TLSConfig), quicConfig)
if err != nil {
log.Logf("quic dial %s: %v", addr, err)
return nil, err
}
return &quicSession{session: session}, nil
}
func (tr *quicTransporter) Multiplex() bool {
return true
}
// QUICConfig is the config for QUIC client and server
type QUICConfig struct {
TLSConfig *tls.Config
Timeout time.Duration
KeepAlive bool
KeepAlivePeriod time.Duration
IdleTimeout time.Duration
Key []byte
}
type quicListener struct {
ln quic.EarlyListener
connChan chan net.Conn
errChan chan error
}
// QUICListener creates a Listener for QUIC proxy server.
func QUICListener(addr string, config *QUICConfig) (Listener, error) {
if config == nil {
config = &QUICConfig{}
}
quicConfig := &quic.Config{
HandshakeIdleTimeout: config.Timeout,
KeepAlivePeriod: config.KeepAlivePeriod,
MaxIdleTimeout: config.IdleTimeout,
Versions: []quic.VersionNumber{
quic.Version1,
quic.VersionDraft29,
},
}
tlsConfig := config.TLSConfig
if tlsConfig == nil {
tlsConfig = DefaultTLSConfig
}
var conn net.PacketConn
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err = net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
if config.Key != nil {
conn = &quicCipherConn{PacketConn: conn, key: config.Key}
}
ln, err := quic.ListenEarly(conn, tlsConfigQUICALPN(tlsConfig), quicConfig)
if err != nil {
return nil, err
}
l := &quicListener{
ln: ln,
connChan: make(chan net.Conn, 1024),
errChan: make(chan error, 1),
}
go l.listenLoop()
return l, nil
}
func (l *quicListener) listenLoop() {
for {
session, err := l.ln.Accept(context.Background())
if err != nil {
log.Log("[quic] accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.sessionLoop(session)
}
}
func (l *quicListener) sessionLoop(session quic.Connection) {
log.Logf("[quic] %s <-> %s", session.RemoteAddr(), session.LocalAddr())
defer log.Logf("[quic] %s >-< %s", session.RemoteAddr(), session.LocalAddr())
for {
stream, err := session.AcceptStream(context.Background())
if err != nil {
log.Log("[quic] accept stream:", err)
session.CloseWithError(quic.ApplicationErrorCode(0), "closed")
return
}
cc := &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()}
select {
case l.connChan <- cc:
default:
cc.Close()
log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr())
}
}
}
func (l *quicListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *quicListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *quicListener) Close() error {
return l.ln.Close()
}
type quicConn struct {
quic.Stream
laddr net.Addr
raddr net.Addr
}
func (c *quicConn) LocalAddr() net.Addr {
return c.laddr
}
func (c *quicConn) RemoteAddr() net.Addr {
return c.raddr
}
type quicCipherConn struct {
net.PacketConn
key []byte
}
func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) {
n, addr, err = conn.PacketConn.ReadFrom(data)
if err != nil {
return
}
b, err := conn.decrypt(data[:n])
if err != nil {
return
}
copy(data, b)
return len(b), addr, nil
}
func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) {
b, err := conn.encrypt(data)
if err != nil {
return
}
_, err = conn.PacketConn.WriteTo(b, addr)
if err != nil {
return
}
return len(b), nil
}
func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, data, nil), nil
}
func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) {
c, err := aes.NewCipher(conn.key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(c)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
return gcm.Open(nil, nonce, ciphertext, nil)
}
func tlsConfigQUICALPN(tlsConfig *tls.Config) *tls.Config {
if tlsConfig == nil {
panic("quic: tlsconfig is nil")
}
tlsConfigQUIC := &tls.Config{}
*tlsConfigQUIC = *tlsConfig
tlsConfigQUIC.NextProtos = []string{"http/3", "quic/v1"}
return tlsConfigQUIC
}

463
quic_test.go

@ -1,463 +0,0 @@
package gost
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"net/http/httptest"
"net/url"
"testing"
)
func httpOverQUICRoundtrip(targetURL string, data []byte,
clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
client := &Client{
Connector: HTTPConnector(clientInfo),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(
UsersHandlerOption(serverInfo...),
),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestHTTPOverQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
for i, tc := range httpProxyTests {
err := httpOverQUICRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers)
if err == nil {
if tc.errStr != "" {
t.Errorf("#%d should failed with error %s", i, tc.errStr)
}
} else {
if tc.errStr == "" {
t.Errorf("#%d got error %v", i, err)
}
if err.Error() != tc.errStr {
t.Errorf("#%d got error %v, want %v", i, err, tc.errStr)
}
}
}
}
func BenchmarkHTTPOverQUIC(b *testing.B) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := QUICListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
client := &Client{
Connector: HTTPConnector(url.UserPassword("admin", "123456")),
Transporter: QUICTransporter(&QUICConfig{KeepAlive: true}),
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(
UsersHandlerOption(url.UserPassword("admin", "123456")),
),
}
go server.Run()
defer server.Close()
for i := 0; i < b.N; i++ {
if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil {
b.Error(err)
}
}
}
func BenchmarkHTTPOverQUICParallel(b *testing.B) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
ln, err := QUICListener("localhost:0", nil)
if err != nil {
b.Error(err)
}
client := &Client{
Connector: HTTPConnector(url.UserPassword("admin", "123456")),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(
UsersHandlerOption(url.UserPassword("admin", "123456")),
),
}
go server.Run()
defer server.Close()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil {
b.Error(err)
}
}
})
}
func socks5OverQUICRoundtrip(targetURL string, data []byte,
clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
client := &Client{
Connector: SOCKS5Connector(clientInfo),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: SOCKS5Handler(
UsersHandlerOption(serverInfo...),
),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestSOCKS5OverQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
for i, tc := range socks5ProxyTests {
err := socks5OverQUICRoundtrip(httpSrv.URL, sendData,
tc.cliUser,
tc.srvUsers,
)
if err == nil {
if !tc.pass {
t.Errorf("#%d should failed", i)
}
} else {
// t.Logf("#%d %v", i, err)
if tc.pass {
t.Errorf("#%d got error: %v", i, err)
}
}
}
}
func socks4OverQUICRoundtrip(targetURL string, data []byte) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
client := &Client{
Connector: SOCKS4Connector(),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: SOCKS4Handler(),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestSOCKS4OverQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
err := socks4OverQUICRoundtrip(httpSrv.URL, sendData)
// t.Logf("#%d %v", i, err)
if err != nil {
t.Errorf("got error: %v", err)
}
}
func socks4aOverQUICRoundtrip(targetURL string, data []byte) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
client := &Client{
Connector: SOCKS4AConnector(),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: SOCKS4Handler(),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestSOCKS4AOverQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
err := socks4aOverQUICRoundtrip(httpSrv.URL, sendData)
// t.Logf("#%d %v", i, err)
if err != nil {
t.Errorf("got error: %v", err)
}
}
func ssOverQUICRoundtrip(targetURL string, data []byte,
clientInfo, serverInfo *url.Userinfo) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
client := &Client{
Connector: ShadowConnector(clientInfo),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: ShadowHandler(
UsersHandlerOption(serverInfo),
),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestSSOverQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
for i, tc := range ssProxyTests {
err := ssOverQUICRoundtrip(httpSrv.URL, sendData,
tc.clientCipher,
tc.serverCipher,
)
if err == nil {
if !tc.pass {
t.Errorf("#%d should failed", i)
}
} else {
// t.Logf("#%d %v", i, err)
if tc.pass {
t.Errorf("#%d got error: %v", i, err)
}
}
}
}
func sniOverQUICRoundtrip(targetURL string, data []byte, host string) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
u, err := url.Parse(targetURL)
if err != nil {
return err
}
client := &Client{
Connector: SNIConnector(host),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: SNIHandler(HostHandlerOption(u.Host)),
}
go server.Run()
defer server.Close()
return sniRoundtrip(client, server, targetURL, data)
}
func TestSNIOverQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
httpsSrv := httptest.NewTLSServer(httpTestHandler)
defer httpsSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
var sniProxyTests = []struct {
targetURL string
host string
pass bool
}{
{httpSrv.URL, "", true},
{httpSrv.URL, "example.com", true},
{httpsSrv.URL, "", true},
{httpsSrv.URL, "example.com", true},
}
for i, tc := range sniProxyTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
err := sniOverQUICRoundtrip(tc.targetURL, sendData, tc.host)
if err == nil {
if !tc.pass {
t.Errorf("#%d should failed", i)
}
} else {
// t.Logf("#%d %v", i, err)
if tc.pass {
t.Errorf("#%d got error: %v", i, err)
}
}
})
}
}
func quicForwardTunnelRoundtrip(targetURL string, data []byte) error {
ln, err := QUICListener("localhost:0", nil)
if err != nil {
return err
}
u, err := url.Parse(targetURL)
if err != nil {
return err
}
client := &Client{
Connector: ForwardConnector(),
Transporter: QUICTransporter(nil),
}
server := &Server{
Listener: ln,
Handler: TCPDirectForwardHandler(u.Host),
}
server.Handler.Init()
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestQUICForwardTunnel(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
err := quicForwardTunnelRoundtrip(httpSrv.URL, sendData)
if err != nil {
t.Error(err)
}
}
func httpOverCipherQUICRoundtrip(targetURL string, data []byte,
clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error {
sum := sha256.Sum256([]byte("12345678"))
cfg := &QUICConfig{
Key: sum[:],
}
ln, err := QUICListener("localhost:0", cfg)
if err != nil {
return err
}
client := &Client{
Connector: HTTPConnector(clientInfo),
Transporter: QUICTransporter(cfg),
}
server := &Server{
Listener: ln,
Handler: HTTPHandler(
UsersHandlerOption(serverInfo...),
),
}
go server.Run()
defer server.Close()
return proxyRoundtrip(client, server, targetURL, data)
}
func TestHTTPOverCipherQUIC(t *testing.T) {
httpSrv := httptest.NewServer(httpTestHandler)
defer httpSrv.Close()
sendData := make([]byte, 128)
rand.Read(sendData)
for i, tc := range httpProxyTests {
err := httpOverCipherQUICRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers)
if err == nil {
if tc.errStr != "" {
t.Errorf("#%d should failed with error %s", i, tc.errStr)
}
} else {
if tc.errStr == "" {
t.Errorf("#%d got error %v", i, err)
}
if err.Error() != tc.errStr {
t.Errorf("#%d got error %v, want %v", i, err, tc.errStr)
}
}
}
}

14
socks.go

@ -168,7 +168,12 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String())
}
if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) {
ctx := context.Background()
inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if ok {
ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP)
}
if selector.Authenticator != nil && !selector.Authenticator.InflowwAuthenticateContext(ctx, req.Username, req.Password) {
resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure)
if err := resp.Write(conn); err != nil {
log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
@ -939,7 +944,12 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) {
fmt.Fprintf(&buf, "%s", host)
log.Log("[route]", buf.String())
cc, err = route.Dial(host,
ctx := context.Background()
inboundAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if ok {
ctx = context.WithValue(ctx, "InboundIP", inboundAddr.IP)
}
cc, err = route.DialContext(ctx, "tcp", host,
TimeoutChainOption(h.options.Timeout),
HostsChainOption(h.options.Hosts),
ResolverChainOption(h.options.Resolver),

Loading…
Cancel
Save