diff --git a/auth.go b/auth.go index 4ded573..c2870a9 100644 --- a/auth.go +++ b/auth.go @@ -2,8 +2,6 @@ package gost import ( "bufio" - "crypto/sha256" - "encoding/hex" "io" "net" "strings" @@ -13,6 +11,9 @@ import ( "github.com/go-log/log" ) +// 防止 OTP 重放 +var usedOTP sync.Map + // Authenticator is an interface for user authentication. type Authenticator interface { Authenticate(user, password string) bool @@ -20,28 +21,41 @@ type Authenticator interface { } func (au *LocalAuthenticator) IFAuthenticate(ip net.IP, user, password string) bool { - if ip != nil { - if !ip.IsLoopback() && !ip.IsPrivate() { - expected := GeneratePass(ip.String(), user) - if expected == password { - return true - } else { - log.Logf("user pass %s/%s, expect pass %s", user, password, expected) - } + if ip == nil { + return false + } + + if isWhiteIP(ip) { + + expected := GeneratePassword(ip.String(), user) + if expected == password { + return true + } else { + log.Logf("user pass %s/%s, expect pass %s", user, password, expected) + } + } else { + // if !ip.IsLoopback() && !ip.IsPrivate() { // 存的时候已经判断. + secret := generateSecret(ip.String(), user) + ok, counter := verifyOTP(secret, password) + + if !ok { + log.Logf("otp verify fail user=%s ip=%s pass=%s", user, ip, password) + return false } + + // 防止 OTP 重放 + key := user + ":" + ip.String() + ":" + password + if _, exists := usedOTP.Load(key); exists { + log.Logf("otp replay attack user=%s ip=%s", user, ip) + return false + } + usedOTP.Store(key, counter) + + return true } return false } -func GeneratePass(ip, user string) string { - src := ip + user + "&&4sg123g[]/~" - hash := sha256.New() - hash.Write([]byte(src)) - hashedSrc := hash.Sum(nil) - hashedSrcHex := hex.EncodeToString(hashedSrc) - return hashedSrcHex -} - // LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs. type LocalAuthenticator struct { kvs map[string]string diff --git a/auth_emai.go b/auth_emai.go new file mode 100644 index 0000000..77a1f9a --- /dev/null +++ b/auth_emai.go @@ -0,0 +1,122 @@ +package gost + +import ( + "fmt" + "net" + "regexp" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/go-log/log" +) + +// 25 端口 465 和 587 +var mailPorts = []string{"25", "465", "587"} + +type EmailACL struct { + emails map[string]struct{} + domains map[string]struct{} + regex []*regexp.Regexp +} + +var emailACL atomic.Value + +func LoadEmailACL(list []string, regexList []string) error { + acl := &EmailACL{ + emails: map[string]struct{}{}, + domains: map[string]struct{}{}, + } + for _, v := range list { + v = strings.ToLower(strings.TrimSpace(v)) + if strings.HasPrefix(v, "@") { + domain := strings.TrimPrefix(v, "@") + acl.domains[domain] = struct{}{} + } else { + acl.emails[v] = struct{}{} + } + } + for _, r := range regexList { + re, err := regexp.Compile(r) + if err != nil { + return err + } + acl.regex = append(acl.regex, re) + } + emailACL.Store(acl) + return nil +} + +func IsEmailAllowed(email string) bool { + acl := emailACL.Load().(*EmailACL) + email = strings.ToLower(strings.TrimSpace(email)) + // 白名单为空 → 拒绝所有 + if len(acl.emails) == 0 && len(acl.domains) == 0 && len(acl.regex) == 0 { + return false + } + // 精确匹配 + if _, ok := acl.emails[email]; ok { + return true + } + // domain 匹配 + parts := strings.Split(email, "@") + if len(parts) == 2 { + domain := parts[1] + if _, ok := acl.domains[domain]; ok { + return true + } + } + // regex + for _, r := range acl.regex { + if r.MatchString(email) { + return true + } + } + return false +} + +func CheckMailFrom(email string) error { + if !IsEmailAllowed(email) { + log.Logf("smtp blocked email: %s", email) + return fmt.Errorf("550 sender not allowed") + } + return nil +} + +type RateLimit struct { + count int + lastTime time.Time +} + +var rateLimitMap sync.Map // key: ip/user, value: *RateLimit + +func CheckRateLimit(ip net.IP, user string, maxPerMinute int) bool { + + key := ip.String() + ":" + user + now := time.Now() + + v, _ := rateLimitMap.LoadOrStore(key, &RateLimit{ + count: 0, + lastTime: now, + }) + + rl := v.(*RateLimit) + + // 超过一分钟窗口 → 重置计数 + if now.Sub(rl.lastTime) > time.Minute { + rl.count = 0 + rl.lastTime = now + } + + if rl.count >= maxPerMinute { + return false + } + + rl.count++ + return true +} + +// if !CheckRateLimit(clientIP, username, 50) { +// return fmt.Errorf("451 Too many messages, rate limit exceeded") +// } diff --git a/auth_ip.go b/auth_ip.go new file mode 100644 index 0000000..f1e3e54 --- /dev/null +++ b/auth_ip.go @@ -0,0 +1,65 @@ +package gost + +import ( + "fmt" + "net" + "strings" + "sync/atomic" +) + +var whiteList atomic.Value + +type IPWhiteList struct { + networks []*net.IPNet +} + +func NewIPWhiteList(list []string) (*IPWhiteList, error) { + wl := &IPWhiteList{} + for _, item := range list { + if strings.Contains(item, "/") { + _, netw, err := net.ParseCIDR(item) + if err != nil { + return nil, err + } + wl.networks = append(wl.networks, netw) + } else { + ip := net.ParseIP(item) + if ip == nil { + return nil, fmt.Errorf("invalid ip: %s", item) + } + mask := net.CIDRMask(32, 32) + wl.networks = append(wl.networks, &net.IPNet{ + IP: ip, + Mask: mask, + }) + } + } + + return wl, nil +} + +func (w *IPWhiteList) Contains(ip net.IP) bool { + if ip == nil { + return false + } + for _, netw := range w.networks { + if netw.Contains(ip) { + return true + } + } + return false +} + +func LoadIPWhiteList(list []string) error { + wl, err := NewIPWhiteList(list) + if err != nil { + return err + } + whiteList.Store(wl) + return nil +} + +func isWhiteIP(ip net.IP) bool { + wl := whiteList.Load().(*IPWhiteList) + return wl.Contains(ip) +} diff --git a/chain.go b/chain.go index 8ad7a89..d426373 100644 --- a/chain.go +++ b/chain.go @@ -199,7 +199,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op default: } - localAddr := getLocalAddr(ctx) + localAddr := getLocalAddr(ctx, options) d := &net.Dialer{ Timeout: timeout, Control: controlFunction, diff --git a/cmd/gost/gost.toml b/cmd/gost/gost.toml new file mode 100644 index 0000000..50960ac --- /dev/null +++ b/cmd/gost/gost.toml @@ -0,0 +1,23 @@ +[auth] +#动态口令周期 (s) 600=10分钟 +dynamic_period = 600 +#时间漂移 +dynamic_skew = 1 + +ip_whitelist = [ + "198.144.184.47", + "108.174.48.108", + "23.94.205.145", + "108.174.48.102", + "198.144.184.108", + "198.144.184.57", + "198.144.184.125", +] + +#动态密钥 +secret = "&&4sg123g[]/~" + +# SMTP 发信白名单 +email_whitelist = ["admin@example.com", "@example.com"] + +email_regex = ["^test[0-9]+@example.com$"] diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 29740da..d0085ac 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -8,9 +8,6 @@ import ( "net/http" "os" "runtime" - "strings" - - _ "net/http/pprof" "github.com/ginuerzh/gost" "github.com/go-log/log" @@ -25,11 +22,6 @@ var ( func init() { - if len(os.Args) == 4 && strings.ToLower(os.Args[1]) == "genpass" { - fmt.Println(gost.GeneratePass(os.Args[2], os.Args[3])) - os.Exit(0) - } - gost.SetLogger(&gost.LogLogger{}) var ( @@ -93,6 +85,8 @@ func main() { gost.DefaultTLSConfig = tlsConfig + gost.LoadAuthConfig() + if err := start(); err != nil { log.Log(err) os.Exit(1) diff --git a/config.go b/config.go new file mode 100644 index 0000000..14fbf71 --- /dev/null +++ b/config.go @@ -0,0 +1,30 @@ +package gost + +import ( + "github.com/BurntSushi/toml" + "github.com/go-log/log" +) + +type Config struct { + Auth struct { + DynamicPeriod int64 `toml:"dynamic_period"` + IPWhiteList []string `toml:"ip_whitelist"` + EmailWhiteList []string `toml:"email_whitelist"` + EmailRegWhiteList []string `toml:"email_regex"` + DynamicSkew int `toml:"dynamic_skew"` + Secret string `toml:"secret"` + } `toml:"auth"` +} + +var config Config + +func LoadAuthConfig() { + + _, err := toml.DecodeFile("auth.toml", &config) + if err != nil { + log.Log("not found auth.toml", err) + } + + LoadIPWhiteList(config.Auth.IPWhiteList) + LoadEmailACL(config.Auth.EmailWhiteList, config.Auth.EmailRegWhiteList) +} diff --git a/go.mod b/go.mod index 2c94c1e..9454d9f 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,9 @@ require ( require ( filippo.io/edwards25519 v1.2.0 // indirect + github.com/BurntSushi/toml v1.6.0 // indirect github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/coreos/go-iptables v0.8.0 // indirect github.com/dchest/siphash v1.2.3 // indirect github.com/google/go-cmp v0.7.0 // indirect @@ -42,6 +44,7 @@ require ( github.com/klauspost/reedsolomon v1.13.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pquerna/otp v1.5.0 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect github.com/xtaci/lossyconn v1.0.0 // indirect diff --git a/go.sum b/go.sum index 0a76a26..420b1ce 100644 --- a/go.sum +++ b/go.sum @@ -3,17 +3,22 @@ filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5E filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/LiamHaworth/go-tproxy v0.0.0-20190726054950-ef7efd7f24ed h1:eqa6queieK8SvoszxCu0WwH7lSVeL4/N/f1JwOMw1G4= github.com/LiamHaworth/go-tproxy v0.0.0-20190726054950-ef7efd7f24ed/go.mod h1:rA52xkgZwql9LRZXWb2arHEFP6qSR48KY2xOfWzEciQ= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coreos/go-iptables v0.8.0 h1:MPc2P89IhuVpLI7ETL/2tx3XZ61VeICZjYqDEgNsPRc= github.com/coreos/go-iptables v0.8.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA= @@ -70,6 +75,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= @@ -83,6 +90,8 @@ github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 h1:XU9h github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601/go.mod h1:mttDPaeLm87u74HMrP+n2tugXvIKWcwff/cqSX0lehY= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= diff --git a/http2.go b/http2.go index e79a73d..f260408 100644 --- a/http2.go +++ b/http2.go @@ -142,7 +142,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, if !ok { // NOTE: There is no real connection to the HTTP2 server at this moment. // So we try to connect to the server to check the server health. - conn, err := opts.Chain.Dial(nil, addr) + conn, err := opts.Chain.Dial(addr) if err != nil { log.Log("http2 dial:", addr, err) return nil, err @@ -156,7 +156,7 @@ func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { - conn, err := opts.Chain.Dial(nil, adr) + conn, err := opts.Chain.Dial(adr) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, err transport := http2.Transport{ TLSClientConfig: tr.tlsConfig, DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { - conn, err := opts.Chain.Dial(nil, addr) + conn, err := opts.Chain.Dial(addr) if err != nil { return nil, err } diff --git a/localAddr.go b/localAddr.go index 3602725..952fcdd 100644 --- a/localAddr.go +++ b/localAddr.go @@ -17,13 +17,6 @@ func getIP(conn net.Conn) (ip net.IP) { return nil } -func getContext1(ip net.IP) (ctx context.Context) { - ctx = context.Background() - if ip != nil { - return context.WithValue(ctx, localAddrKey, ip) - } - return -} func getContext(conn net.Conn, parentCtx context.Context) (ctx context.Context) { IP := getIP(conn) if IP != nil { @@ -48,8 +41,11 @@ func GetSshIP(conn ssh.ConnMetadata) (ip net.IP) { return nil } -func getLocalAddr(ctx context.Context) (addr net.Addr) { +func getLocalAddr(ctx context.Context, options *ChainOptions) (addr net.Addr) { ip := GetIP(ctx) + if ip == nil && options != nil { + ip = options.IP + } if ip != nil { addr = &net.TCPAddr{ IP: ip, diff --git a/resolver.go b/resolver.go index 4cecc40..8a2259a 100644 --- a/resolver.go +++ b/resolver.go @@ -282,7 +282,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { ctx := context.Background() for _, ns := range r.copyServers() { - ips, err = r.resolve(ip, ctx, ns.exchanger, host) + ips, err = r.resolve(ctx, ns.exchanger, host) if err != nil { log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) continue @@ -299,7 +299,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { return } -func (r *resolver) resolve(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { if ex == nil { return } @@ -309,36 +309,36 @@ func (r *resolver) resolve(ip net.IP, ctx context.Context, ex Exchanger, host st r.mux.RUnlock() if prefer == "ipv6" { // prefer ipv6 - if ips, err = r.resolve6(ip, ctx, ex, host); len(ips) > 0 { + if ips, err = r.resolve6(ctx, ex, host); len(ips) > 0 { return } - return r.resolve4(ip, ctx, ex, host) + return r.resolve4(ctx, ex, host) } - if ips, err = r.resolve4(ip, ctx, ex, host); len(ips) > 0 { + if ips, err = r.resolve4(ctx, ex, host); len(ips) > 0 { return } - return r.resolve6(ip, ctx, ex, host) + return r.resolve6(ctx, ex, host) } -func (r *resolver) resolve4(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve4(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { mq := dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeA) - return r.resolveIPs(ip, ctx, ex, &mq) + return r.resolveIPs(ctx, ex, &mq) } -func (r *resolver) resolve6(ip net.IP, ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { +func (r *resolver) resolve6(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { mq := dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) - return r.resolveIPs(ip, ctx, ex, &mq) + return r.resolveIPs(ctx, ex, &mq) } -func (r *resolver) resolveIPs(ip net.IP, ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { +func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { key := newResolverCacheKey(&mq.Question[0]) mr := r.cache.loadCache(key) if mr == nil { r.addSubnetOpt(mq) - mr, err = r.exchangeMsg(ip, ctx, ex, mq) + mr, err = r.exchangeMsg(ctx, ex, mq) if err != nil { return } @@ -753,9 +753,9 @@ func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { } } -func (ex *dnsTCPExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) { +func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.options.chain.DialContext(ip, ctx, + c, err := ex.options.chain.DialContext(ctx, "tcp", ex.addr, TimeoutChainOption(ex.options.timeout), ) @@ -810,8 +810,8 @@ func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption } } -func (ex *dotExchanger) dial(ip net.IP, ctx context.Context, network, address string) (conn net.Conn, err error) { - conn, err = ex.options.chain.DialContext(ip, ctx, +func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + conn, err = ex.options.chain.DialContext(ctx, network, address, TimeoutChainOption(ex.options.timeout), ) @@ -823,9 +823,9 @@ func (ex *dotExchanger) dial(ip net.IP, ctx context.Context, network, address st return } -func (ex *dotExchanger) Exchange(ip net.IP, ctx context.Context, query []byte) ([]byte, error) { +func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.dial(ip, ctx, "tcp", ex.addr) + c, err := ex.dial(ctx, "tcp", ex.addr) if err != nil { return nil, err } @@ -881,9 +881,8 @@ func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOp return ex } -func (ex *dohExchanger) dialContext(ip net.IP, ctx context.Context, network, address string) (net.Conn, error) { - // todo:: - return ex.options.chain.DialContext(ip, ctx, +func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return ex.options.chain.DialContext(ctx, network, address, TimeoutChainOption(ex.options.timeout), ) diff --git a/socks.go b/socks.go index 08db5b6..650537a 100644 --- a/socks.go +++ b/socks.go @@ -431,7 +431,7 @@ func (tr *socks5MuxBindTransporter) Dial(addr string, options ...DialOption) (co if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(nil, addr) + conn, err = opts.Chain.Dial(addr) } if err != nil { return diff --git a/ssh.go b/ssh.go index 007e71d..efcf327 100644 --- a/ssh.go +++ b/ssh.go @@ -194,7 +194,7 @@ func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(nil, addr) + conn, err = opts.Chain.Dial(addr) } if err != nil { return @@ -308,7 +308,7 @@ func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn n if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(nil, addr) + conn, err = opts.Chain.Dial(addr) } if err != nil { return diff --git a/tcp.go b/tcp.go index 1e1ecd9..42a1532 100644 --- a/tcp.go +++ b/tcp.go @@ -25,7 +25,7 @@ func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, er if opts.Chain == nil { return net.DialTimeout("tcp", addr, timeout) } - return opts.Chain.Dial(nil, addr) + return opts.Chain.Dial(addr) } func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { diff --git a/tls.go b/tls.go index a0931f6..8526c6f 100644 --- a/tls.go +++ b/tls.go @@ -74,7 +74,7 @@ func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Co if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(nil, addr) + conn, err = opts.Chain.Dial(addr) } if err != nil { return diff --git a/totp.go b/totp.go new file mode 100644 index 0000000..fa1dcba --- /dev/null +++ b/totp.go @@ -0,0 +1,72 @@ +package gost + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "time" +) + +func VerifyOTP(secret, pass string) bool { + now := time.Now().UTC().Unix() + skew := config.Auth.DynamicSkew + period := config.Auth.DynamicPeriod + + for i := -skew; i <= skew; i++ { + + t := (now / period) + int64(i) + + code := generateTOTP(secret, t) + + if code == pass { + return true + } + } + return false +} + +func generateTOTP(secret string, counter int64) string { + key := []byte(secret) + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(counter)) + h := hmac.New(sha1.New, key) + h.Write(buf[:]) + hash := h.Sum(nil) + offset := hash[len(hash)-1] & 0x0f + truncated := + (binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff) % 1000000 + return fmt.Sprintf("%06d", truncated) +} + +func generateSecret(ip, user string) string { + src := ip + ":" + user + ":" + config.Auth.Secret + hash := sha256.Sum256([]byte(src)) + return hex.EncodeToString(hash[:]) +} + +func verifyOTP(secret, pass string) (bool, int64) { + now := time.Now().UTC().Unix() + skew := config.Auth.DynamicSkew + period := config.Auth.DynamicPeriod + + for i := -skew; i <= skew; i++ { + counter := (now / period) + int64(i) + code := generateTOTP(secret, counter) + if code == pass { + return true, counter + } + } + return false, 0 +} + +func GeneratePassword(ip, user string) string { + src := ip + user + config.Auth.Secret + hash := sha256.New() + hash.Write([]byte(src)) + hashedSrc := hash.Sum(nil) + hashedSrcHex := hex.EncodeToString(hashedSrc) + return hashedSrcHex +} diff --git a/vsock.go b/vsock.go index 3d98cb8..fabdbcf 100644 --- a/vsock.go +++ b/vsock.go @@ -27,7 +27,7 @@ func (tr *vsockTransporter) Dial(addr string, options ...DialOption) (net.Conn, } return vsock.Dial(vAddr.ContextID, vAddr.Port, nil) } - return opts.Chain.Dial(nil, addr) + return opts.Chain.Dial(addr) } func parseUint32(s string) (uint32, error) { diff --git a/ws.go b/ws.go index 51d6dda..9dc8f0d 100644 --- a/ws.go +++ b/ws.go @@ -104,7 +104,7 @@ func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Con if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(nil, addr) + conn, err = opts.Chain.Dial(addr) } if err != nil { return @@ -261,7 +261,7 @@ func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Co if opts.Chain == nil { conn, err = net.DialTimeout("tcp", addr, timeout) } else { - conn, err = opts.Chain.Dial(nil, addr) + conn, err = opts.Chain.Dial(addr) } if err != nil { return