mirror of https://github.com/ginuerzh/gost
7 changed files with 511 additions and 25 deletions
@ -0,0 +1,43 @@ |
|||
package gost |
|||
|
|||
import ( |
|||
"testing" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
) |
|||
|
|||
func TestNodeDefaultWhitelist(t *testing.T) { |
|||
assert := assert.New(t) |
|||
|
|||
node, _ := ParseProxyNode("http2://localhost:8000") |
|||
|
|||
assert.True(node.Can("connect", "google.pl:80")) |
|||
assert.True(node.Can("connect", "google.pl:443")) |
|||
assert.True(node.Can("connect", "google.pl:22")) |
|||
assert.True(node.Can("bind", "google.pl:80")) |
|||
assert.True(node.Can("bind", "google.com:80")) |
|||
} |
|||
|
|||
func TestNodeWhitelist(t *testing.T) { |
|||
assert := assert.New(t) |
|||
|
|||
node, _ := ParseProxyNode("http2://localhost:8000?whitelist=connect:google.pl:80,443") |
|||
|
|||
assert.True(node.Can("connect", "google.pl:80")) |
|||
assert.True(node.Can("connect", "google.pl:443")) |
|||
assert.False(node.Can("connect", "google.pl:22")) |
|||
assert.False(node.Can("bind", "google.pl:80")) |
|||
assert.False(node.Can("bind", "google.com:80")) |
|||
} |
|||
|
|||
func TestNodeBlacklist(t *testing.T) { |
|||
assert := assert.New(t) |
|||
|
|||
node, _ := ParseProxyNode("http2://localhost:8000?blacklist=connect:google.pl:80,443") |
|||
|
|||
assert.False(node.Can("connect", "google.pl:80")) |
|||
assert.False(node.Can("connect", "google.pl:443")) |
|||
assert.True(node.Can("connect", "google.pl:22")) |
|||
assert.True(node.Can("bind", "google.pl:80")) |
|||
assert.True(node.Can("bind", "google.com:80")) |
|||
} |
|||
@ -0,0 +1,185 @@ |
|||
package gost |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"strconv" |
|||
"strings" |
|||
|
|||
glob "github.com/ryanuber/go-glob" |
|||
) |
|||
|
|||
type PortRange struct { |
|||
Min, Max int |
|||
} |
|||
|
|||
type PortSet []PortRange |
|||
|
|||
type StringSet []string |
|||
|
|||
type Permission struct { |
|||
Actions StringSet |
|||
Hosts StringSet |
|||
Ports PortSet |
|||
} |
|||
|
|||
type Permissions []Permission |
|||
|
|||
func minint(x, y int) int { |
|||
if x < y { |
|||
return x |
|||
} |
|||
return y |
|||
} |
|||
|
|||
func maxint(x, y int) int { |
|||
if x > y { |
|||
return x |
|||
} |
|||
return y |
|||
} |
|||
|
|||
func (ir *PortRange) Contains(value int) bool { |
|||
return value >= ir.Min && value <= ir.Max |
|||
} |
|||
|
|||
func ParsePortRange(s string) (*PortRange, error) { |
|||
if s == "*" { |
|||
return &PortRange{Min: 0, Max: 65535}, nil |
|||
} |
|||
|
|||
minmax := strings.Split(s, "-") |
|||
switch len(minmax) { |
|||
case 1: |
|||
port, err := strconv.Atoi(s) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if port < 0 || port > 65535 { |
|||
return nil, fmt.Errorf("invalid port: %s", s) |
|||
} |
|||
return &PortRange{Min: port, Max: port}, nil |
|||
case 2: |
|||
min, err := strconv.Atoi(minmax[0]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
max, err := strconv.Atoi(minmax[1]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
realmin := maxint(0, minint(min, max)) |
|||
realmax := minint(65535, maxint(min, max)) |
|||
|
|||
return &PortRange{Min: realmin, Max: realmax}, nil |
|||
default: |
|||
return nil, fmt.Errorf("invalid range: %s", s) |
|||
} |
|||
} |
|||
|
|||
func (ps *PortSet) Contains(value int) bool { |
|||
for _, portRange := range *ps { |
|||
if portRange.Contains(value) { |
|||
return true |
|||
} |
|||
} |
|||
|
|||
return false |
|||
} |
|||
|
|||
func ParsePortSet(s string) (*PortSet, error) { |
|||
ps := &PortSet{} |
|||
|
|||
if s == "" { |
|||
return nil, errors.New("must specify at least one port") |
|||
} |
|||
|
|||
ranges := strings.Split(s, ",") |
|||
|
|||
for _, r := range ranges { |
|||
portRange, err := ParsePortRange(r) |
|||
|
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
*ps = append(*ps, *portRange) |
|||
} |
|||
|
|||
return ps, nil |
|||
} |
|||
|
|||
func (ss *StringSet) Contains(subj string) bool { |
|||
for _, s := range *ss { |
|||
if glob.Glob(s, subj) { |
|||
return true |
|||
} |
|||
} |
|||
|
|||
return false |
|||
} |
|||
|
|||
func ParseStringSet(s string) (*StringSet, error) { |
|||
ss := &StringSet{} |
|||
if s == "" { |
|||
return nil, errors.New("cannot be empty") |
|||
} |
|||
|
|||
*ss = strings.Split(s, ",") |
|||
|
|||
return ss, nil |
|||
} |
|||
|
|||
func (ps *Permissions) Can(action string, host string, port int) bool { |
|||
for _, p := range *ps { |
|||
if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) { |
|||
return true |
|||
} |
|||
} |
|||
|
|||
return false |
|||
} |
|||
|
|||
func ParsePermissions(s string) (*Permissions, error) { |
|||
ps := &Permissions{} |
|||
|
|||
if s == "" { |
|||
return &Permissions{}, nil |
|||
} |
|||
|
|||
perms := strings.Split(s, "+") |
|||
|
|||
for _, perm := range perms { |
|||
parts := strings.Split(perm, ":") |
|||
|
|||
switch len(parts) { |
|||
case 3: |
|||
actions, err := ParseStringSet(parts[0]) |
|||
|
|||
if err != nil { |
|||
return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0]) |
|||
} |
|||
|
|||
hosts, err := ParseStringSet(parts[1]) |
|||
|
|||
if err != nil { |
|||
return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1]) |
|||
} |
|||
|
|||
ports, err := ParsePortSet(parts[2]) |
|||
|
|||
if err != nil { |
|||
return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2]) |
|||
} |
|||
|
|||
permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports} |
|||
|
|||
*ps = append(*ps, permission) |
|||
default: |
|||
return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm) |
|||
} |
|||
} |
|||
|
|||
return ps, nil |
|||
} |
|||
@ -0,0 +1,152 @@ |
|||
package gost |
|||
|
|||
import ( |
|||
"fmt" |
|||
"testing" |
|||
) |
|||
|
|||
var portRangeTests = []struct { |
|||
in string |
|||
out *PortRange |
|||
}{ |
|||
{"1", &PortRange{Min: 1, Max: 1}}, |
|||
{"1-3", &PortRange{Min: 1, Max: 3}}, |
|||
{"3-1", &PortRange{Min: 1, Max: 3}}, |
|||
{"0-100000", &PortRange{Min: 0, Max: 65535}}, |
|||
{"*", &PortRange{Min: 0, Max: 65535}}, |
|||
} |
|||
|
|||
var stringSetTests = []struct { |
|||
in string |
|||
out *StringSet |
|||
}{ |
|||
{"*", &StringSet{"*"}}, |
|||
{"google.pl,google.com", &StringSet{"google.pl", "google.com"}}, |
|||
} |
|||
|
|||
var portSetTests = []struct { |
|||
in string |
|||
out *PortSet |
|||
}{ |
|||
{"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}}, |
|||
{"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}}, |
|||
{"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}}, |
|||
{"*", &PortSet{PortRange{Min: 0, Max: 65535}}}, |
|||
} |
|||
|
|||
var permissionsTests = []struct { |
|||
in string |
|||
out *Permissions |
|||
}{ |
|||
{"", &Permissions{}}, |
|||
{"*:*:*", &Permissions{ |
|||
Permission{ |
|||
Actions: StringSet{"*"}, |
|||
Hosts: StringSet{"*"}, |
|||
Ports: PortSet{PortRange{Min: 0, Max: 65535}}, |
|||
}, |
|||
}}, |
|||
{"bind:127.0.0.1,localhost:80,443,8000-8100+connect:*.google.pl:80,443", &Permissions{ |
|||
Permission{ |
|||
Actions: StringSet{"bind"}, |
|||
Hosts: StringSet{"127.0.0.1", "localhost"}, |
|||
Ports: PortSet{ |
|||
PortRange{Min: 80, Max: 80}, |
|||
PortRange{Min: 443, Max: 443}, |
|||
PortRange{Min: 8000, Max: 8100}, |
|||
}, |
|||
}, |
|||
Permission{ |
|||
Actions: StringSet{"connect"}, |
|||
Hosts: StringSet{"*.google.pl"}, |
|||
Ports: PortSet{ |
|||
PortRange{Min: 80, Max: 80}, |
|||
PortRange{Min: 443, Max: 443}, |
|||
}, |
|||
}, |
|||
}}, |
|||
} |
|||
|
|||
func TestPortRangeParse(t *testing.T) { |
|||
for _, test := range portRangeTests { |
|||
actual, err := ParsePortRange(test.in) |
|||
if err != nil { |
|||
t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) |
|||
} else if *actual != *test.out { |
|||
t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestPortRangeContains(t *testing.T) { |
|||
actual, _ := ParsePortRange("5-10") |
|||
|
|||
if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { |
|||
t.Errorf("5-10 should contain 5, 7 and 10") |
|||
} |
|||
|
|||
if actual.Contains(4) || actual.Contains(11) { |
|||
t.Errorf("5-10 should not contain 4, 11") |
|||
} |
|||
} |
|||
|
|||
func TestStringSetParse(t *testing.T) { |
|||
for _, test := range stringSetTests { |
|||
actual, err := ParseStringSet(test.in) |
|||
if err != nil { |
|||
t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err) |
|||
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { |
|||
t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestStringSetContains(t *testing.T) { |
|||
ss, _ := ParseStringSet("google.pl,*.google.com") |
|||
|
|||
if !ss.Contains("google.pl") || !ss.Contains("www.google.com") { |
|||
t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com") |
|||
} |
|||
|
|||
if ss.Contains("www.google.pl") || ss.Contains("foobar.com") { |
|||
t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com") |
|||
} |
|||
} |
|||
|
|||
func TestPortSetParse(t *testing.T) { |
|||
for _, test := range portSetTests { |
|||
actual, err := ParsePortSet(test.in) |
|||
if err != nil { |
|||
t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) |
|||
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { |
|||
t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func TestPortSetContains(t *testing.T) { |
|||
actual, _ := ParsePortSet("5-10,20-30") |
|||
|
|||
if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { |
|||
t.Errorf("5-10,20-30 should contain 5, 7 and 10") |
|||
} |
|||
|
|||
if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) { |
|||
t.Errorf("5-10,20-30 should contain 20, 27 and 30") |
|||
} |
|||
|
|||
if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) { |
|||
t.Errorf("5-10,20-30 should not contain 4, 11, 31") |
|||
} |
|||
} |
|||
|
|||
func TestPermissionsParse(t *testing.T) { |
|||
for _, test := range permissionsTests { |
|||
actual, err := ParsePermissions(test.in) |
|||
if err != nil { |
|||
t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err) |
|||
} else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { |
|||
t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out) |
|||
} |
|||
} |
|||
} |
|||
Loading…
Reference in new issue