Skip to content

Commit

Permalink
add reloader for authenticator
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Jan 9, 2019
1 parent 1930da5 commit 6266356
Show file tree
Hide file tree
Showing 31 changed files with 492 additions and 187 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 3 additions & 0 deletions cmd/gost/.config/secrets.txt → .config/secrets.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# period for live reloading
reload 3s

# username password

$test.admin$ $123456$
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ gost - GO Simple Tunnel
* 多端口监听
* 可设置转发代理,支持多级转发(代理链)
* 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议
* Web代理支持[探测防御](https://docs.ginuerzh.xyz/gost/probe_resist/)
* [支持多种隧道类型](https://docs.ginuerzh.xyz/gost/configuration/)
* [SOCKS5代理支持TLS协商加密](https://docs.ginuerzh.xyz/gost/socks/)
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/socks/)
Expand Down
1 change: 1 addition & 0 deletions README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Features
* Listening on multiple ports
* Multi-level forward proxy - proxy chain
* Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support
* [Probing resistance](https://docs.ginuerzh.xyz/gost/en/probe_resist/) support for web proxy
* [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/)
* [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/)
* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/)
Expand Down
155 changes: 155 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package gost

import (
"bufio"
"io"
"strings"
"sync"
"time"
)

// Authenticator is an interface for user authentication.
type Authenticator interface {
Authenticate(user, password string) bool
}

// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs.
type LocalAuthenticator struct {
kvs map[string]string
period time.Duration
stopped chan struct{}
mux sync.RWMutex
}

// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos.
func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator {
return &LocalAuthenticator{
kvs: kvs,
stopped: make(chan struct{}),
}
}

// Authenticate checks the validity of the provided user-password pair.
func (au *LocalAuthenticator) Authenticate(user, password string) bool {
if au == nil {
return true
}

au.mux.RLock()
defer au.mux.RUnlock()

if len(au.kvs) == 0 {
return true
}

v, ok := au.kvs[user]
return ok && (v == "" || password == v)
}

// Add adds a key-value pair to the Authenticator.
func (au *LocalAuthenticator) Add(k, v string) {
au.mux.Lock()
defer au.mux.Unlock()
if au.kvs == nil {
au.kvs = make(map[string]string)
}
au.kvs[k] = v
}

// Reload parses config from r, then live reloads the bypass.
func (au *LocalAuthenticator) Reload(r io.Reader) error {
var period time.Duration
kvs := make(map[string]string)

if r == nil || au.Stopped() {
return nil
}

// splitLine splits a line text by white space.
// A line started with '#' will be ignored, otherwise it is valid.
split := func(line string) []string {
if line == "" {
return nil
}
line = strings.Replace(line, "\t", " ", -1)
line = strings.TrimSpace(line)

if strings.IndexByte(line, '#') == 0 {
return nil
}

var ss []string
for _, s := range strings.Split(line, " ") {
if s = strings.TrimSpace(s); s != "" {
ss = append(ss, s)
}
}
return ss
}

scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
ss := split(line)
if len(ss) == 0 {
continue
}

switch ss[0] {
case "reload": // reload option
if len(ss) > 1 {
period, _ = time.ParseDuration(ss[1])
}
default:
var k, v string
k = ss[0]
if len(ss) > 1 {
v = ss[1]
}
kvs[k] = v
}
}

if err := scanner.Err(); err != nil {
return err
}

au.mux.Lock()
defer au.mux.Unlock()

au.period = period
au.kvs = kvs

return nil
}

// Period returns the reload period.
func (au *LocalAuthenticator) Period() time.Duration {
if au.Stopped() {
return -1
}

au.mux.RLock()
defer au.mux.RUnlock()

return au.period
}

// Stop stops reloading.
func (au *LocalAuthenticator) Stop() {
select {
case <-au.stopped:
default:
close(au.stopped)
}
}

// Stopped checks whether the reloader is stopped.
func (au *LocalAuthenticator) Stopped() bool {
select {
case <-au.stopped:
return true
default:
return false
}
}
191 changes: 191 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package gost

import (
"bytes"
"fmt"
"io"
"net/url"
"testing"
"time"
)

var localAuthenticatorTests = []struct {
clientUser *url.Userinfo
serverUsers []*url.Userinfo
valid bool
}{
{nil, nil, true},
{nil, []*url.Userinfo{url.User("admin")}, false},
{nil, []*url.Userinfo{url.UserPassword("", "123456")}, false},
{nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false},

{url.User("admin"), nil, true},
{url.User("admin"), []*url.Userinfo{url.User("admin")}, true},
{url.User("admin"), []*url.Userinfo{url.User("test")}, false},
{url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false},
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false},
{url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
{url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false},

{url.UserPassword("", ""), nil, true},
{url.UserPassword("", "123456"), nil, true},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false},
{url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false},

{url.UserPassword("admin", "123456"), nil, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false},
{url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true},

{url.UserPassword("admin", "123456"), []*url.Userinfo{
url.UserPassword("test", "123"),
url.UserPassword("admin", "123456"),
}, true},
}

func TestLocalAuthenticator(t *testing.T) {
for i, tc := range localAuthenticatorTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
au := NewLocalAuthenticator(nil)
for _, u := range tc.serverUsers {
if u != nil {
p, _ := u.Password()
au.Add(u.Username(), p)
}
}

var u, p string
if tc.clientUser != nil {
u = tc.clientUser.Username()
p, _ = tc.clientUser.Password()
}
if au.Authenticate(u, p) != tc.valid {
t.Error("authenticate result should be", tc.valid)
}
})
}
}

var localAuthenticatorReloadTests = []struct {
r io.Reader
period time.Duration
kvs map[string]string
stopped bool
}{
{
r: nil,
period: 0,
kvs: nil,
},
{
r: bytes.NewBufferString(""),
period: 0,
},
{
r: bytes.NewBufferString("reload 10s"),
period: 10 * time.Second,
},
{
r: bytes.NewBufferString("# reload 10s\n"),
},
{
r: bytes.NewBufferString("reload 10s\n#admin"),
period: 10 * time.Second,
},
{
r: bytes.NewBufferString("reload 10s\nadmin"),
period: 10 * time.Second,
kvs: map[string]string{
"admin": "",
},
},
{
r: bytes.NewBufferString("# reload 10s\nadmin"),
kvs: map[string]string{
"admin": "",
},
},
{
r: bytes.NewBufferString("# reload 10s\nadmin #123456"),
kvs: map[string]string{
"admin": "#123456",
},
stopped: true,
},
{
r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"),
kvs: map[string]string{
"admin": "#123456",
"test": "123456",
},
stopped: true,
},
{
r: bytes.NewBufferString(`
$test.admin$ $123456$
@test.admin@ @123456@
test.admin# #123456#
test.admin\admin 123456
`),
kvs: map[string]string{
"$test.admin$": "$123456$",
"@test.admin@": "@123456@",
"test.admin#": "#123456#",
"test.admin\\admin": "123456",
},
stopped: true,
},
}

func TestLocalAuthenticatorReload(t *testing.T) {
isEquals := func(a, b map[string]string) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
if len(a) != len(b) {
return false
}

for k, v := range a {
if b[k] != v {
return false
}
}
return true
}
for i, tc := range localAuthenticatorReloadTests {
tc := tc
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
au := NewLocalAuthenticator(nil)

if err := au.Reload(tc.r); err != nil {
t.Error(err)
}
if au.Period() != tc.period {
t.Errorf("#%d test failed: period value should be %v, got %v",
i, tc.period, au.Period())
}
if !isEquals(au.kvs, tc.kvs) {
t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs)
}

if tc.stopped {
au.Stop()
if au.Period() >= 0 {
t.Errorf("period of the stopped reloader should be minus value")
}
au.Stop()
}
if au.Stopped() != tc.stopped {
t.Errorf("#%d test failed: stopped value should be %v, got %v",
i, tc.stopped, au.Stopped())
}
})
}
}
Loading

0 comments on commit 6266356

Please sign in to comment.