mirror of https://github.com/ginuerzh/gost
10 changed files with 420 additions and 2 deletions
@ -0,0 +1,259 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"bufio" |
||||
|
"errors" |
||||
|
"io" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
"sync" |
||||
|
"sync/atomic" |
||||
|
"time" |
||||
|
) |
||||
|
|
||||
|
type Limiter interface { |
||||
|
CheckRate(key string, checkConcurrent bool) (func(), bool) |
||||
|
} |
||||
|
|
||||
|
func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) { |
||||
|
limiter := LocalLimiter{ |
||||
|
buckets: map[string]*limiterBucket{}, |
||||
|
concurrent: map[string]chan bool{}, |
||||
|
stopped: make(chan struct{}), |
||||
|
} |
||||
|
if cfg == "" || user == "" { |
||||
|
return &limiter, nil |
||||
|
} |
||||
|
if err := limiter.AddRule(user, cfg); err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
return &limiter, nil |
||||
|
} |
||||
|
|
||||
|
// Token Bucket
|
||||
|
type limiterBucket struct { |
||||
|
max int64 |
||||
|
cur int64 |
||||
|
duration int64 |
||||
|
batch int64 |
||||
|
} |
||||
|
|
||||
|
type LocalLimiter struct { |
||||
|
buckets map[string]*limiterBucket |
||||
|
concurrent map[string]chan bool |
||||
|
mux sync.RWMutex |
||||
|
stopped chan struct{} |
||||
|
period time.Duration |
||||
|
} |
||||
|
|
||||
|
func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) { |
||||
|
if checkConcurrent { |
||||
|
done, ok := l.checkConcurrent(key) |
||||
|
if !ok { |
||||
|
return nil, false |
||||
|
} |
||||
|
if t := l.getToken(key); !t { |
||||
|
done() |
||||
|
return nil, false |
||||
|
} |
||||
|
return done, true |
||||
|
} else { |
||||
|
if t := l.getToken(key); !t { |
||||
|
return nil, false |
||||
|
} |
||||
|
return nil, true |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (l *LocalLimiter) AddRule(user string, cfg string) error { |
||||
|
if user == "" { |
||||
|
return nil |
||||
|
} |
||||
|
if cfg == "" { |
||||
|
//reload need check old limit exists
|
||||
|
if _, ok := l.buckets[user]; ok { |
||||
|
delete(l.buckets, user) |
||||
|
} |
||||
|
if _, ok := l.concurrent[user]; ok { |
||||
|
delete(l.concurrent, user) |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
args := strings.Split(cfg, ",") |
||||
|
if len(args) < 2 || len(args) > 3 { |
||||
|
return errors.New("parse limiter fail:" + cfg) |
||||
|
} |
||||
|
if len(args) == 2 { |
||||
|
args = append(args, "0") |
||||
|
} |
||||
|
|
||||
|
duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64) |
||||
|
count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64) |
||||
|
cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64) |
||||
|
if e1 != nil || e2 != nil || e3 != nil { |
||||
|
return errors.New("parse limiter fail:" + cfg) |
||||
|
} |
||||
|
// 0 means not limit
|
||||
|
if duration > 0 && count > 0 { |
||||
|
bu := &limiterBucket{ |
||||
|
cur: count * 10, |
||||
|
max: count * 10, |
||||
|
duration: duration * 100, |
||||
|
batch: count, |
||||
|
} |
||||
|
go func() { |
||||
|
for { |
||||
|
time.Sleep(time.Millisecond * time.Duration(bu.duration)) |
||||
|
if bu.cur+bu.batch > bu.max { |
||||
|
bu.cur = bu.max |
||||
|
} else { |
||||
|
atomic.AddInt64(&bu.cur, bu.batch) |
||||
|
} |
||||
|
} |
||||
|
}() |
||||
|
l.buckets[user] = bu |
||||
|
} else { |
||||
|
if _, ok := l.buckets[user]; ok { |
||||
|
delete(l.buckets, user) |
||||
|
} |
||||
|
} |
||||
|
// zero means not limit
|
||||
|
if cur > 0 { |
||||
|
l.concurrent[user] = make(chan bool, cur) |
||||
|
} else { |
||||
|
if _, ok := l.concurrent[user]; ok { |
||||
|
delete(l.concurrent, user) |
||||
|
} |
||||
|
} |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Reload parses config from r, then live reloads the LocalLimiter.
|
||||
|
func (l *LocalLimiter) Reload(r io.Reader) error { |
||||
|
var period time.Duration |
||||
|
kvs := make(map[string]string) |
||||
|
|
||||
|
if r == nil || l.Stopped() { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// splitLine splits a line text by white space.
|
||||
|
// A line started with '#' will be ignored, otherwise it is valid.
|
||||
|
split := func(line string) []string { |
||||
|
if line == "" { |
||||
|
return nil |
||||
|
} |
||||
|
line = strings.Replace(line, "\t", " ", -1) |
||||
|
line = strings.TrimSpace(line) |
||||
|
|
||||
|
if strings.IndexByte(line, '#') == 0 { |
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
var ss []string |
||||
|
for _, s := range strings.Split(line, " ") { |
||||
|
if s = strings.TrimSpace(s); s != "" { |
||||
|
ss = append(ss, s) |
||||
|
} |
||||
|
} |
||||
|
return ss |
||||
|
} |
||||
|
|
||||
|
scanner := bufio.NewScanner(r) |
||||
|
for scanner.Scan() { |
||||
|
line := scanner.Text() |
||||
|
ss := split(line) |
||||
|
if len(ss) == 0 { |
||||
|
continue |
||||
|
} |
||||
|
|
||||
|
switch ss[0] { |
||||
|
case "reload": // reload option
|
||||
|
if len(ss) > 1 { |
||||
|
period, _ = time.ParseDuration(ss[1]) |
||||
|
} |
||||
|
default: |
||||
|
var k, v string |
||||
|
k = ss[0] |
||||
|
if len(ss) > 2 { |
||||
|
v = ss[2] |
||||
|
} |
||||
|
kvs[k] = v |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
if err := scanner.Err(); err != nil { |
||||
|
return err |
||||
|
} |
||||
|
|
||||
|
l.mux.Lock() |
||||
|
defer l.mux.Unlock() |
||||
|
|
||||
|
l.period = period |
||||
|
for user, args := range kvs { |
||||
|
err := l.AddRule(user, args) |
||||
|
if err != nil { |
||||
|
return err |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return nil |
||||
|
} |
||||
|
|
||||
|
// Period returns the reload period.
|
||||
|
func (l *LocalLimiter) Period() time.Duration { |
||||
|
if l.Stopped() { |
||||
|
return -1 |
||||
|
} |
||||
|
|
||||
|
l.mux.RLock() |
||||
|
defer l.mux.RUnlock() |
||||
|
|
||||
|
return l.period |
||||
|
} |
||||
|
|
||||
|
// Stop stops reloading.
|
||||
|
func (l *LocalLimiter) Stop() { |
||||
|
select { |
||||
|
case <-l.stopped: |
||||
|
default: |
||||
|
close(l.stopped) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Stopped checks whether the reloader is stopped.
|
||||
|
func (l *LocalLimiter) Stopped() bool { |
||||
|
select { |
||||
|
case <-l.stopped: |
||||
|
return true |
||||
|
default: |
||||
|
return false |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (l *LocalLimiter) getToken(key string) bool { |
||||
|
b, ok := l.buckets[key] |
||||
|
if !ok || b == nil { |
||||
|
return true |
||||
|
} |
||||
|
if b.cur <= 0 { |
||||
|
return false |
||||
|
} |
||||
|
atomic.AddInt64(&b.cur, -10) |
||||
|
return true |
||||
|
} |
||||
|
|
||||
|
func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) { |
||||
|
c, ok := l.concurrent[key] |
||||
|
if !ok || c == nil { |
||||
|
return func() {}, true |
||||
|
} |
||||
|
select { |
||||
|
case c <- true: |
||||
|
return func() { |
||||
|
<-c |
||||
|
}, true |
||||
|
default: |
||||
|
return nil, false |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,69 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
func TestNewLocalLimiter(t *testing.T) { |
||||
|
items := []struct { |
||||
|
user string |
||||
|
args string |
||||
|
success bool |
||||
|
}{ |
||||
|
{"admin", "10,1", true}, |
||||
|
{"admin", "", true}, |
||||
|
{"admin", "10,1,1", true}, |
||||
|
{"admin", "10", false}, |
||||
|
{"admin", "0,1", true}, |
||||
|
{"admin", "0,1,1", true}, |
||||
|
{"admin", "a,b", false}, |
||||
|
{"", "", true}, |
||||
|
{"", "1,2", true}, |
||||
|
} |
||||
|
for i, item := range items { |
||||
|
item := item |
||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { |
||||
|
_, err := NewLocalLimiter(item.user, item.args) |
||||
|
if (err == nil) != item.success { |
||||
|
t.Error("test NewLocalLimiter fail", item.user, item.args) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestCheckRate(t *testing.T) { |
||||
|
items := []struct { |
||||
|
user string |
||||
|
args string |
||||
|
testUser string |
||||
|
checkCount int |
||||
|
shouldSuccessCount int |
||||
|
}{ |
||||
|
{"admin", "10,3", "admin", 10, 3}, |
||||
|
{"admin", "10,3,0", "admin", 10, 3}, |
||||
|
{"admin", "10,3,2", "admin", 10, 2}, |
||||
|
{"admin", "0,0", "admin", 10, 10}, |
||||
|
{"admin", "10,3,5", "admin", 10, 3}, |
||||
|
{"admin", "10,3,5", "admin22", 10, 10}, |
||||
|
{"admin", "0,0,5", "admin", 10, 5}, |
||||
|
} |
||||
|
for i, item := range items { |
||||
|
item := item |
||||
|
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { |
||||
|
l, err := NewLocalLimiter(item.user, item.args) |
||||
|
if err != nil { |
||||
|
t.Error("test NewLocalLimiter fail", item.user, item.args) |
||||
|
} |
||||
|
successCount := 0 |
||||
|
for j := 0; j < item.checkCount; j++ { |
||||
|
if _, ok := l.CheckRate(item.testUser, true); ok { |
||||
|
successCount++ |
||||
|
} |
||||
|
} |
||||
|
if successCount != item.shouldSuccessCount { |
||||
|
t.Error("test localLimiter fail", item) |
||||
|
} |
||||
|
}) |
||||
|
} |
||||
|
} |
||||
Loading…
Reference in new issue