mirror of https://github.com/ginuerzh/gost
31 changed files with 492 additions and 187 deletions
@ -1,3 +1,6 @@ |
|||||
|
# period for live reloading |
||||
|
reload 3s |
||||
|
|
||||
# username password |
# username password |
||||
|
|
||||
$test.admin$ $123456$ |
$test.admin$ $123456$ |
||||
@ -0,0 +1,155 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"bufio" |
||||
|
"io" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
// Authenticator is an interface for user authentication.
|
||||
|
type Authenticator interface { |
||||
|
Authenticate(user, password string) bool |
||||
|
} |
||||
|
|
||||
|
// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs.
|
||||
|
type LocalAuthenticator struct { |
||||
|
kvs map[string]string |
||||
|
period time.Duration |
||||
|
stopped chan struct{} |
||||
|
mux sync.RWMutex |
||||
|
} |
||||
|
|
||||
|
// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos.
|
||||
|
func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator { |
||||
|
return &LocalAuthenticator{ |
||||
|
kvs: kvs, |
||||
|
stopped: make(chan struct{}), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Authenticate checks the validity of the provided user-password pair.
|
||||
|
func (au *LocalAuthenticator) Authenticate(user, password string) bool { |
||||
|
if au == nil { |
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
au.mux.RLock() |
||||
|
defer au.mux.RUnlock() |
||||
|
|
||||
|
if len(au.kvs) == 0 { |
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
v, ok := au.kvs[user] |
||||
|
return ok && (v == "" || password == v) |
||||
|
} |
||||
|
|
||||
|
// Add adds a key-value pair to the Authenticator.
|
||||
|
func (au *LocalAuthenticator) Add(k, v string) { |
||||
|
au.mux.Lock() |
||||
|
defer au.mux.Unlock() |
||||
|
if au.kvs == nil { |
||||
|
au.kvs = make(map[string]string) |
||||
|
} |
||||
|
au.kvs[k] = v |
||||
|
} |
||||
|
|
||||
|
// Reload parses config from r, then live reloads the bypass.
|
||||
|
func (au *LocalAuthenticator) Reload(r io.Reader) error { |
||||
|
var period time.Duration |
||||
|
kvs := make(map[string]string) |
||||
|
|
||||
|
if r == nil || au.Stopped() { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// splitLine splits a line text by white space.
|
||||
|
// A line started with '#' will be ignored, otherwise it is valid.
|
||||
|
split := func(line string) []string { |
||||
|
if line == "" { |
||||
|
return nil |
||||
|
} |
||||
|
line = strings.Replace(line, "\t", " ", -1) |
||||
|
line = strings.TrimSpace(line) |
||||
|
|
||||
|
if strings.IndexByte(line, '#') == 0 { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
var ss []string |
||||
|
for _, s := range strings.Split(line, " ") { |
||||
|
if s = strings.TrimSpace(s); s != "" { |
||||
|
ss = append(ss, s) |
||||
|
} |
||||
|
} |
||||
|
return ss |
||||
|
} |
||||
|
|
||||
|
scanner := bufio.NewScanner(r) |
||||
|
for scanner.Scan() { |
||||
|
line := scanner.Text() |
||||
|
ss := split(line) |
||||
|
if len(ss) == 0 { |
||||
|
continue |
||||
|
} |
||||
|
|
||||
|
switch ss[0] { |
||||
|
case "reload": // reload option
|
||||
|
if len(ss) > 1 { |
||||
|
period, _ = time.ParseDuration(ss[1]) |
||||
|
} |
||||
|
default: |
||||
|
var k, v string |
||||
|
k = ss[0] |
||||
|
if len(ss) > 1 { |
||||
|
v = ss[1] |
||||
|
} |
||||
|
kvs[k] = v |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if err := scanner.Err(); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
au.mux.Lock() |
||||
|
defer au.mux.Unlock() |
||||
|
|
||||
|
au.period = period |
||||
|
au.kvs = kvs |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Period returns the reload period.
|
||||
|
func (au *LocalAuthenticator) Period() time.Duration { |
||||
|
if au.Stopped() { |
||||
|
return -1 |
||||
|
} |
||||
|
|
||||
|
au.mux.RLock() |
||||
|
defer au.mux.RUnlock() |
||||
|
|
||||
|
return au.period |
||||
|
} |
||||
|
|
||||
|
// Stop stops reloading.
|
||||
|
func (au *LocalAuthenticator) Stop() { |
||||
|
select { |
||||
|
case <-au.stopped: |
||||
|
default: |
||||
|
close(au.stopped) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Stopped checks whether the reloader is stopped.
|
||||
|
func (au *LocalAuthenticator) Stopped() bool { |
||||
|
select { |
||||
|
case <-au.stopped: |
||||
|
return true |
||||
|
default: |
||||
|
return false |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,191 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"bytes" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
"net/url" |
||||
|
"testing" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
var localAuthenticatorTests = []struct { |
||||
|
clientUser *url.Userinfo |
||||
|
serverUsers []*url.Userinfo |
||||
|
valid bool |
||||
|
}{ |
||||
|
{nil, nil, true}, |
||||
|
{nil, []*url.Userinfo{url.User("admin")}, false}, |
||||
|
{nil, []*url.Userinfo{url.UserPassword("", "123456")}, false}, |
||||
|
{nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, |
||||
|
|
||||
|
{url.User("admin"), nil, true}, |
||||
|
{url.User("admin"), []*url.Userinfo{url.User("admin")}, true}, |
||||
|
{url.User("admin"), []*url.Userinfo{url.User("test")}, false}, |
||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, |
||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, |
||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, |
||||
|
{url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, |
||||
|
|
||||
|
{url.UserPassword("", ""), nil, true}, |
||||
|
{url.UserPassword("", "123456"), nil, true}, |
||||
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, |
||||
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false}, |
||||
|
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, |
||||
|
|
||||
|
{url.UserPassword("admin", "123456"), nil, true}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, |
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, |
||||
|
|
||||
|
{url.UserPassword("admin", "123456"), []*url.Userinfo{ |
||||
|
url.UserPassword("test", "123"), |
||||
|
url.UserPassword("admin", "123456"), |
||||
|
}, true}, |
||||
|
} |
||||
|
|
||||
|
func TestLocalAuthenticator(t *testing.T) { |
||||
|
for i, tc := range localAuthenticatorTests { |
||||
|
tc := tc |
||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { |
||||
|
au := NewLocalAuthenticator(nil) |
||||
|
for _, u := range tc.serverUsers { |
||||
|
if u != nil { |
||||
|
p, _ := u.Password() |
||||
|
au.Add(u.Username(), p) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
var u, p string |
||||
|
if tc.clientUser != nil { |
||||
|
u = tc.clientUser.Username() |
||||
|
p, _ = tc.clientUser.Password() |
||||
|
} |
||||
|
if au.Authenticate(u, p) != tc.valid { |
||||
|
t.Error("authenticate result should be", tc.valid) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
var localAuthenticatorReloadTests = []struct { |
||||
|
r io.Reader |
||||
|
period time.Duration |
||||
|
kvs map[string]string |
||||
|
stopped bool |
||||
|
}{ |
||||
|
{ |
||||
|
r: nil, |
||||
|
period: 0, |
||||
|
kvs: nil, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString(""), |
||||
|
period: 0, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("reload 10s"), |
||||
|
period: 10 * time.Second, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("# reload 10s\n"), |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("reload 10s\n#admin"), |
||||
|
period: 10 * time.Second, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("reload 10s\nadmin"), |
||||
|
period: 10 * time.Second, |
||||
|
kvs: map[string]string{ |
||||
|
"admin": "", |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("# reload 10s\nadmin"), |
||||
|
kvs: map[string]string{ |
||||
|
"admin": "", |
||||
|
}, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("# reload 10s\nadmin #123456"), |
||||
|
kvs: map[string]string{ |
||||
|
"admin": "#123456", |
||||
|
}, |
||||
|
stopped: true, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"), |
||||
|
kvs: map[string]string{ |
||||
|
"admin": "#123456", |
||||
|
"test": "123456", |
||||
|
}, |
||||
|
stopped: true, |
||||
|
}, |
||||
|
{ |
||||
|
r: bytes.NewBufferString(` |
||||
|
$test.admin$ $123456$ |
||||
|
@test.admin@ @123456@ |
||||
|
test.admin# #123456# |
||||
|
test.admin\admin 123456 |
||||
|
`), |
||||
|
kvs: map[string]string{ |
||||
|
"$test.admin$": "$123456$", |
||||
|
"@test.admin@": "@123456@", |
||||
|
"test.admin#": "#123456#", |
||||
|
"test.admin\\admin": "123456", |
||||
|
}, |
||||
|
stopped: true, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
func TestLocalAuthenticatorReload(t *testing.T) { |
||||
|
isEquals := func(a, b map[string]string) bool { |
||||
|
if len(a) == 0 && len(b) == 0 { |
||||
|
return true |
||||
|
} |
||||
|
if len(a) != len(b) { |
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
for k, v := range a { |
||||
|
if b[k] != v { |
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
return true |
||||
|
} |
||||
|
for i, tc := range localAuthenticatorReloadTests { |
||||
|
tc := tc |
||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { |
||||
|
au := NewLocalAuthenticator(nil) |
||||
|
|
||||
|
if err := au.Reload(tc.r); err != nil { |
||||
|
t.Error(err) |
||||
|
} |
||||
|
if au.Period() != tc.period { |
||||
|
t.Errorf("#%d test failed: period value should be %v, got %v", |
||||
|
i, tc.period, au.Period()) |
||||
|
} |
||||
|
if !isEquals(au.kvs, tc.kvs) { |
||||
|
t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs) |
||||
|
} |
||||
|
|
||||
|
if tc.stopped { |
||||
|
au.Stop() |
||||
|
if au.Period() >= 0 { |
||||
|
t.Errorf("period of the stopped reloader should be minus value") |
||||
|
} |
||||
|
au.Stop() |
||||
|
} |
||||
|
if au.Stopped() != tc.stopped { |
||||
|
t.Errorf("#%d test failed: stopped value should be %v, got %v", |
||||
|
i, tc.stopped, au.Stopped()) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
@ -1 +0,0 @@ |
|||||
Hello World! |
|
||||
Loading…
Reference in new issue