mirror of https://github.com/ginuerzh/gost
31 changed files with 492 additions and 187 deletions
@ -1,3 +1,6 @@ |
|||
# period for live reloading |
|||
reload 3s |
|||
|
|||
# username password |
|||
|
|||
$test.admin$ $123456$ |
|||
@ -0,0 +1,155 @@ |
|||
package gost |
|||
|
|||
import ( |
|||
"bufio" |
|||
"io" |
|||
"strings" |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
// Authenticator is an interface for user authentication.
|
|||
type Authenticator interface { |
|||
Authenticate(user, password string) bool |
|||
} |
|||
|
|||
// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs.
|
|||
type LocalAuthenticator struct { |
|||
kvs map[string]string |
|||
period time.Duration |
|||
stopped chan struct{} |
|||
mux sync.RWMutex |
|||
} |
|||
|
|||
// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos.
|
|||
func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator { |
|||
return &LocalAuthenticator{ |
|||
kvs: kvs, |
|||
stopped: make(chan struct{}), |
|||
} |
|||
} |
|||
|
|||
// Authenticate checks the validity of the provided user-password pair.
|
|||
func (au *LocalAuthenticator) Authenticate(user, password string) bool { |
|||
if au == nil { |
|||
return true |
|||
} |
|||
|
|||
au.mux.RLock() |
|||
defer au.mux.RUnlock() |
|||
|
|||
if len(au.kvs) == 0 { |
|||
return true |
|||
} |
|||
|
|||
v, ok := au.kvs[user] |
|||
return ok && (v == "" || password == v) |
|||
} |
|||
|
|||
// Add adds a key-value pair to the Authenticator.
|
|||
func (au *LocalAuthenticator) Add(k, v string) { |
|||
au.mux.Lock() |
|||
defer au.mux.Unlock() |
|||
if au.kvs == nil { |
|||
au.kvs = make(map[string]string) |
|||
} |
|||
au.kvs[k] = v |
|||
} |
|||
|
|||
// Reload parses config from r, then live reloads the bypass.
|
|||
func (au *LocalAuthenticator) Reload(r io.Reader) error { |
|||
var period time.Duration |
|||
kvs := make(map[string]string) |
|||
|
|||
if r == nil || au.Stopped() { |
|||
return nil |
|||
} |
|||
|
|||
// splitLine splits a line text by white space.
|
|||
// A line started with '#' will be ignored, otherwise it is valid.
|
|||
split := func(line string) []string { |
|||
if line == "" { |
|||
return nil |
|||
} |
|||
line = strings.Replace(line, "\t", " ", -1) |
|||
line = strings.TrimSpace(line) |
|||
|
|||
if strings.IndexByte(line, '#') == 0 { |
|||
return nil |
|||
} |
|||
|
|||
var ss []string |
|||
for _, s := range strings.Split(line, " ") { |
|||
if s = strings.TrimSpace(s); s != "" { |
|||
ss = append(ss, s) |
|||
} |
|||
} |
|||
return ss |
|||
} |
|||
|
|||
scanner := bufio.NewScanner(r) |
|||
for scanner.Scan() { |
|||
line := scanner.Text() |
|||
ss := split(line) |
|||
if len(ss) == 0 { |
|||
continue |
|||
} |
|||
|
|||
switch ss[0] { |
|||
case "reload": // reload option
|
|||
if len(ss) > 1 { |
|||
period, _ = time.ParseDuration(ss[1]) |
|||
} |
|||
default: |
|||
var k, v string |
|||
k = ss[0] |
|||
if len(ss) > 1 { |
|||
v = ss[1] |
|||
} |
|||
kvs[k] = v |
|||
} |
|||
} |
|||
|
|||
if err := scanner.Err(); err != nil { |
|||
return err |
|||
} |
|||
|
|||
au.mux.Lock() |
|||
defer au.mux.Unlock() |
|||
|
|||
au.period = period |
|||
au.kvs = kvs |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// Period returns the reload period.
|
|||
func (au *LocalAuthenticator) Period() time.Duration { |
|||
if au.Stopped() { |
|||
return -1 |
|||
} |
|||
|
|||
au.mux.RLock() |
|||
defer au.mux.RUnlock() |
|||
|
|||
return au.period |
|||
} |
|||
|
|||
// Stop stops reloading.
|
|||
func (au *LocalAuthenticator) Stop() { |
|||
select { |
|||
case <-au.stopped: |
|||
default: |
|||
close(au.stopped) |
|||
} |
|||
} |
|||
|
|||
// Stopped checks whether the reloader is stopped.
|
|||
func (au *LocalAuthenticator) Stopped() bool { |
|||
select { |
|||
case <-au.stopped: |
|||
return true |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
@ -0,0 +1,191 @@ |
|||
package gost |
|||
|
|||
import ( |
|||
"bytes" |
|||
"fmt" |
|||
"io" |
|||
"net/url" |
|||
"testing" |
|||
"time" |
|||
) |
|||
|
|||
var localAuthenticatorTests = []struct { |
|||
clientUser *url.Userinfo |
|||
serverUsers []*url.Userinfo |
|||
valid bool |
|||
}{ |
|||
{nil, nil, true}, |
|||
{nil, []*url.Userinfo{url.User("admin")}, false}, |
|||
{nil, []*url.Userinfo{url.UserPassword("", "123456")}, false}, |
|||
{nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, |
|||
|
|||
{url.User("admin"), nil, true}, |
|||
{url.User("admin"), []*url.Userinfo{url.User("admin")}, true}, |
|||
{url.User("admin"), []*url.Userinfo{url.User("test")}, false}, |
|||
{url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, |
|||
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, |
|||
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, |
|||
{url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, |
|||
|
|||
{url.UserPassword("", ""), nil, true}, |
|||
{url.UserPassword("", "123456"), nil, true}, |
|||
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, |
|||
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false}, |
|||
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, |
|||
|
|||
{url.UserPassword("admin", "123456"), nil, true}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, |
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, |
|||
|
|||
{url.UserPassword("admin", "123456"), []*url.Userinfo{ |
|||
url.UserPassword("test", "123"), |
|||
url.UserPassword("admin", "123456"), |
|||
}, true}, |
|||
} |
|||
|
|||
func TestLocalAuthenticator(t *testing.T) { |
|||
for i, tc := range localAuthenticatorTests { |
|||
tc := tc |
|||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { |
|||
au := NewLocalAuthenticator(nil) |
|||
for _, u := range tc.serverUsers { |
|||
if u != nil { |
|||
p, _ := u.Password() |
|||
au.Add(u.Username(), p) |
|||
} |
|||
} |
|||
|
|||
var u, p string |
|||
if tc.clientUser != nil { |
|||
u = tc.clientUser.Username() |
|||
p, _ = tc.clientUser.Password() |
|||
} |
|||
if au.Authenticate(u, p) != tc.valid { |
|||
t.Error("authenticate result should be", tc.valid) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
|
|||
var localAuthenticatorReloadTests = []struct { |
|||
r io.Reader |
|||
period time.Duration |
|||
kvs map[string]string |
|||
stopped bool |
|||
}{ |
|||
{ |
|||
r: nil, |
|||
period: 0, |
|||
kvs: nil, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString(""), |
|||
period: 0, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("reload 10s"), |
|||
period: 10 * time.Second, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("# reload 10s\n"), |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("reload 10s\n#admin"), |
|||
period: 10 * time.Second, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("reload 10s\nadmin"), |
|||
period: 10 * time.Second, |
|||
kvs: map[string]string{ |
|||
"admin": "", |
|||
}, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("# reload 10s\nadmin"), |
|||
kvs: map[string]string{ |
|||
"admin": "", |
|||
}, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("# reload 10s\nadmin #123456"), |
|||
kvs: map[string]string{ |
|||
"admin": "#123456", |
|||
}, |
|||
stopped: true, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"), |
|||
kvs: map[string]string{ |
|||
"admin": "#123456", |
|||
"test": "123456", |
|||
}, |
|||
stopped: true, |
|||
}, |
|||
{ |
|||
r: bytes.NewBufferString(` |
|||
$test.admin$ $123456$ |
|||
@test.admin@ @123456@ |
|||
test.admin# #123456# |
|||
test.admin\admin 123456 |
|||
`), |
|||
kvs: map[string]string{ |
|||
"$test.admin$": "$123456$", |
|||
"@test.admin@": "@123456@", |
|||
"test.admin#": "#123456#", |
|||
"test.admin\\admin": "123456", |
|||
}, |
|||
stopped: true, |
|||
}, |
|||
} |
|||
|
|||
func TestLocalAuthenticatorReload(t *testing.T) { |
|||
isEquals := func(a, b map[string]string) bool { |
|||
if len(a) == 0 && len(b) == 0 { |
|||
return true |
|||
} |
|||
if len(a) != len(b) { |
|||
return false |
|||
} |
|||
|
|||
for k, v := range a { |
|||
if b[k] != v { |
|||
return false |
|||
} |
|||
} |
|||
return true |
|||
} |
|||
for i, tc := range localAuthenticatorReloadTests { |
|||
tc := tc |
|||
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { |
|||
au := NewLocalAuthenticator(nil) |
|||
|
|||
if err := au.Reload(tc.r); err != nil { |
|||
t.Error(err) |
|||
} |
|||
if au.Period() != tc.period { |
|||
t.Errorf("#%d test failed: period value should be %v, got %v", |
|||
i, tc.period, au.Period()) |
|||
} |
|||
if !isEquals(au.kvs, tc.kvs) { |
|||
t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs) |
|||
} |
|||
|
|||
if tc.stopped { |
|||
au.Stop() |
|||
if au.Period() >= 0 { |
|||
t.Errorf("period of the stopped reloader should be minus value") |
|||
} |
|||
au.Stop() |
|||
} |
|||
if au.Stopped() != tc.stopped { |
|||
t.Errorf("#%d test failed: stopped value should be %v, got %v", |
|||
i, tc.stopped, au.Stopped()) |
|||
} |
|||
}) |
|||
} |
|||
} |
|||
@ -1 +0,0 @@ |
|||
Hello World! |
|||
Loading…
Reference in new issue