mirror of https://github.com/ginuerzh/gost
committed by
GitHub
8 changed files with 525 additions and 28 deletions
@ -0,0 +1,43 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"testing" |
||||
|
|
||||
|
"github.com/stretchr/testify/assert" |
||||
|
) |
||||
|
|
||||
|
func TestNodeDefaultWhitelist(t *testing.T) { |
||||
|
assert := assert.New(t) |
||||
|
|
||||
|
node, _ := ParseProxyNode("http2://localhost:8000") |
||||
|
|
||||
|
assert.True(node.Can("connect", "google.pl:80")) |
||||
|
assert.True(node.Can("connect", "google.pl:443")) |
||||
|
assert.True(node.Can("connect", "google.pl:22")) |
||||
|
assert.True(node.Can("bind", "google.pl:80")) |
||||
|
assert.True(node.Can("bind", "google.com:80")) |
||||
|
} |
||||
|
|
||||
|
func TestNodeWhitelist(t *testing.T) { |
||||
|
assert := assert.New(t) |
||||
|
|
||||
|
node, _ := ParseProxyNode("http2://localhost:8000?whitelist=connect:google.pl:80,443") |
||||
|
|
||||
|
assert.True(node.Can("connect", "google.pl:80")) |
||||
|
assert.True(node.Can("connect", "google.pl:443")) |
||||
|
assert.False(node.Can("connect", "google.pl:22")) |
||||
|
assert.False(node.Can("bind", "google.pl:80")) |
||||
|
assert.False(node.Can("bind", "google.com:80")) |
||||
|
} |
||||
|
|
||||
|
func TestNodeBlacklist(t *testing.T) { |
||||
|
assert := assert.New(t) |
||||
|
|
||||
|
node, _ := ParseProxyNode("http2://localhost:8000?blacklist=connect:google.pl:80,443") |
||||
|
|
||||
|
assert.False(node.Can("connect", "google.pl:80")) |
||||
|
assert.False(node.Can("connect", "google.pl:443")) |
||||
|
assert.True(node.Can("connect", "google.pl:22")) |
||||
|
assert.True(node.Can("bind", "google.pl:80")) |
||||
|
assert.True(node.Can("bind", "google.com:80")) |
||||
|
} |
||||
@ -0,0 +1,185 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"strconv" |
||||
|
"strings" |
||||
|
|
||||
|
glob "github.com/ryanuber/go-glob" |
||||
|
) |
||||
|
|
||||
|
type PortRange struct { |
||||
|
Min, Max int |
||||
|
} |
||||
|
|
||||
|
type PortSet []PortRange |
||||
|
|
||||
|
type StringSet []string |
||||
|
|
||||
|
type Permission struct { |
||||
|
Actions StringSet |
||||
|
Hosts StringSet |
||||
|
Ports PortSet |
||||
|
} |
||||
|
|
||||
|
type Permissions []Permission |
||||
|
|
||||
|
func minint(x, y int) int { |
||||
|
if x < y { |
||||
|
return x |
||||
|
} |
||||
|
return y |
||||
|
} |
||||
|
|
||||
|
func maxint(x, y int) int { |
||||
|
if x > y { |
||||
|
return x |
||||
|
} |
||||
|
return y |
||||
|
} |
||||
|
|
||||
|
func (ir *PortRange) Contains(value int) bool { |
||||
|
return value >= ir.Min && value <= ir.Max |
||||
|
} |
||||
|
|
||||
|
func ParsePortRange(s string) (*PortRange, error) { |
||||
|
if s == "*" { |
||||
|
return &PortRange{Min: 0, Max: 65535}, nil |
||||
|
} |
||||
|
|
||||
|
minmax := strings.Split(s, "-") |
||||
|
switch len(minmax) { |
||||
|
case 1: |
||||
|
port, err := strconv.Atoi(s) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
if port < 0 || port > 65535 { |
||||
|
return nil, fmt.Errorf("invalid port: %s", s) |
||||
|
} |
||||
|
return &PortRange{Min: port, Max: port}, nil |
||||
|
case 2: |
||||
|
min, err := strconv.Atoi(minmax[0]) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
max, err := strconv.Atoi(minmax[1]) |
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
realmin := maxint(0, minint(min, max)) |
||||
|
realmax := minint(65535, maxint(min, max)) |
||||
|
|
||||
|
return &PortRange{Min: realmin, Max: realmax}, nil |
||||
|
default: |
||||
|
return nil, fmt.Errorf("invalid range: %s", s) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (ps *PortSet) Contains(value int) bool { |
||||
|
for _, portRange := range *ps { |
||||
|
if portRange.Contains(value) { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
func ParsePortSet(s string) (*PortSet, error) { |
||||
|
ps := &PortSet{} |
||||
|
|
||||
|
if s == "" { |
||||
|
return nil, errors.New("must specify at least one port") |
||||
|
} |
||||
|
|
||||
|
ranges := strings.Split(s, ",") |
||||
|
|
||||
|
for _, r := range ranges { |
||||
|
portRange, err := ParsePortRange(r) |
||||
|
|
||||
|
if err != nil { |
||||
|
return nil, err |
||||
|
} |
||||
|
|
||||
|
*ps = append(*ps, *portRange) |
||||
|
} |
||||
|
|
||||
|
return ps, nil |
||||
|
} |
||||
|
|
||||
|
func (ss *StringSet) Contains(subj string) bool { |
||||
|
for _, s := range *ss { |
||||
|
if glob.Glob(s, subj) { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
func ParseStringSet(s string) (*StringSet, error) { |
||||
|
ss := &StringSet{} |
||||
|
if s == "" { |
||||
|
return nil, errors.New("cannot be empty") |
||||
|
} |
||||
|
|
||||
|
*ss = strings.Split(s, ",") |
||||
|
|
||||
|
return ss, nil |
||||
|
} |
||||
|
|
||||
|
func (ps *Permissions) Can(action string, host string, port int) bool { |
||||
|
for _, p := range *ps { |
||||
|
if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) { |
||||
|
return true |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return false |
||||
|
} |
||||
|
|
||||
|
func ParsePermissions(s string) (*Permissions, error) { |
||||
|
ps := &Permissions{} |
||||
|
|
||||
|
if s == "" { |
||||
|
return &Permissions{}, nil |
||||
|
} |
||||
|
|
||||
|
perms := strings.Split(s, " ") |
||||
|
|
||||
|
for _, perm := range perms { |
||||
|
parts := strings.Split(perm, ":") |
||||
|
|
||||
|
switch len(parts) { |
||||
|
case 3: |
||||
|
actions, err := ParseStringSet(parts[0]) |
||||
|
|
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0]) |
||||
|
} |
||||
|
|
||||
|
hosts, err := ParseStringSet(parts[1]) |
||||
|
|
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1]) |
||||
|
} |
||||
|
|
||||
|
ports, err := ParsePortSet(parts[2]) |
||||
|
|
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2]) |
||||
|
} |
||||
|
|
||||
|
permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports} |
||||
|
|
||||
|
*ps = append(*ps, permission) |
||||
|
default: |
||||
|
return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return ps, nil |
||||
|
} |
||||
@ -0,0 +1,152 @@ |
|||||
|
package gost |
||||
|
|
||||
|
import ( |
||||
|
"fmt" |
||||
|
"testing" |
||||
|
) |
||||
|
|
||||
|
var portRangeTests = []struct { |
||||
|
in string |
||||
|
out *PortRange |
||||
|
}{ |
||||
|
{"1", &PortRange{Min: 1, Max: 1}}, |
||||
|
{"1-3", &PortRange{Min: 1, Max: 3}}, |
||||
|
{"3-1", &PortRange{Min: 1, Max: 3}}, |
||||
|
{"0-100000", &PortRange{Min: 0, Max: 65535}}, |
||||
|
{"*", &PortRange{Min: 0, Max: 65535}}, |
||||
|
} |
||||
|
|
||||
|
var stringSetTests = []struct { |
||||
|
in string |
||||
|
out *StringSet |
||||
|
}{ |
||||
|
{"*", &StringSet{"*"}}, |
||||
|
{"google.pl,google.com", &StringSet{"google.pl", "google.com"}}, |
||||
|
} |
||||
|
|
||||
|
var portSetTests = []struct { |
||||
|
in string |
||||
|
out *PortSet |
||||
|
}{ |
||||
|
{"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}}, |
||||
|
{"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}}, |
||||
|
{"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}}, |
||||
|
{"*", &PortSet{PortRange{Min: 0, Max: 65535}}}, |
||||
|
} |
||||
|
|
||||
|
var permissionsTests = []struct { |
||||
|
in string |
||||
|
out *Permissions |
||||
|
}{ |
||||
|
{"", &Permissions{}}, |
||||
|
{"*:*:*", &Permissions{ |
||||
|
Permission{ |
||||
|
Actions: StringSet{"*"}, |
||||
|
Hosts: StringSet{"*"}, |
||||
|
Ports: PortSet{PortRange{Min: 0, Max: 65535}}, |
||||
|
}, |
||||
|
}}, |
||||
|
{"bind:127.0.0.1,localhost:80,443,8000-8100 connect:*.google.pl:80,443", &Permissions{ |
||||
|
Permission{ |
||||
|
Actions: StringSet{"bind"}, |
||||
|
Hosts: StringSet{"127.0.0.1", "localhost"}, |
||||
|
Ports: PortSet{ |
||||
|
PortRange{Min: 80, Max: 80}, |
||||
|
PortRange{Min: 443, Max: 443}, |
||||
|
PortRange{Min: 8000, Max: 8100}, |
||||
|
}, |
||||
|
}, |
||||
|
Permission{ |
||||
|
Actions: StringSet{"connect"}, |
||||
|
Hosts: StringSet{"*.google.pl"}, |
||||
|
Ports: PortSet{ |
||||
|
PortRange{Min: 80, Max: 80}, |
||||
|
PortRange{Min: 443, Max: 443}, |
||||
|
}, |
||||
|
}, |
||||
|
}}, |
||||
|
} |
||||
|
|
||||
|
func TestPortRangeParse(t *testing.T) { |
||||
|
for _, test := range portRangeTests { |
||||
|
actual, err := ParsePortRange(test.in) |
||||
|
if err != nil { |
||||
|
t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) |
||||
|
} else if *actual != *test.out { |
||||
|
t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestPortRangeContains(t *testing.T) { |
||||
|
actual, _ := ParsePortRange("5-10") |
||||
|
|
||||
|
if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { |
||||
|
t.Errorf("5-10 should contain 5, 7 and 10") |
||||
|
} |
||||
|
|
||||
|
if actual.Contains(4) || actual.Contains(11) { |
||||
|
t.Errorf("5-10 should not contain 4, 11") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestStringSetParse(t *testing.T) { |
||||
|
for _, test := range stringSetTests { |
||||
|
actual, err := ParseStringSet(test.in) |
||||
|
if err != nil { |
||||
|
t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err) |
||||
|
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { |
||||
|
t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestStringSetContains(t *testing.T) { |
||||
|
ss, _ := ParseStringSet("google.pl,*.google.com") |
||||
|
|
||||
|
if !ss.Contains("google.pl") || !ss.Contains("www.google.com") { |
||||
|
t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com") |
||||
|
} |
||||
|
|
||||
|
if ss.Contains("www.google.pl") || ss.Contains("foobar.com") { |
||||
|
t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestPortSetParse(t *testing.T) { |
||||
|
for _, test := range portSetTests { |
||||
|
actual, err := ParsePortSet(test.in) |
||||
|
if err != nil { |
||||
|
t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) |
||||
|
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { |
||||
|
t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestPortSetContains(t *testing.T) { |
||||
|
actual, _ := ParsePortSet("5-10,20-30") |
||||
|
|
||||
|
if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { |
||||
|
t.Errorf("5-10,20-30 should contain 5, 7 and 10") |
||||
|
} |
||||
|
|
||||
|
if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) { |
||||
|
t.Errorf("5-10,20-30 should contain 20, 27 and 30") |
||||
|
} |
||||
|
|
||||
|
if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) { |
||||
|
t.Errorf("5-10,20-30 should not contain 4, 11, 31") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func TestPermissionsParse(t *testing.T) { |
||||
|
for _, test := range permissionsTests { |
||||
|
actual, err := ParsePermissions(test.in) |
||||
|
if err != nil { |
||||
|
t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err) |
||||
|
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { |
||||
|
t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out) |
||||
|
} |
||||
|
} |
||||
|
} |
||||
Loading…
Reference in new issue