Skip to content

Commit

Permalink
Make start of TND fallible
Browse files Browse the repository at this point in the history
Signed-off-by: hwipl <[email protected]>
  • Loading branch information
hwipl committed Dec 13, 2023
1 parent c95f854 commit 0033777
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 33 deletions.
36 changes: 24 additions & 12 deletions pkg/tnd/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ type Detector struct {
servers []*https.Server
dialer *net.Dialer

// route and file watch
rw routes.Watcher
fw files.Watcher

// timer
timer *time.Timer

Expand Down Expand Up @@ -106,16 +110,8 @@ func (d *Detector) resetTimer() {
func (d *Detector) start() {
// signal stop to user via results
defer close(d.results)

// start route watching
rw := routes.NewWatch(d.probes)
rw.Start()
defer rw.Stop()

// start file watching
fw := files.NewWatch(d.probes)
fw.Start()
defer fw.Stop()
defer d.rw.Stop()
defer d.fw.Stop()

// set timer for periodic checks
d.timer = time.NewTimer(d.config.UntrustedTimer)
Expand Down Expand Up @@ -177,8 +173,21 @@ func (d *Detector) start() {
}

// Start starts the trusted network detection.
func (d *Detector) Start() {
func (d *Detector) Start() error {
// start route watching
if err := d.rw.Start(); err != nil {
return err
}

// start file watching
if err := d.fw.Start(); err != nil {
d.rw.Stop()
return err
}

// start detector
go d.start()
return nil
}

// Stop stops the running TND.
Expand All @@ -205,12 +214,15 @@ func (d *Detector) Results() chan bool {

// NewDetector returns a new Detector.
func NewDetector(config *Config) *Detector {
probes := make(chan struct{})
return &Detector{
config: config,
probes: make(chan struct{}),
probes: probes,
results: make(chan bool),
done: make(chan struct{}),
dialer: &net.Dialer{},
rw: routes.NewWatch(probes),
fw: files.NewWatch(probes),

probeResults: make(chan bool),
}
Expand Down
82 changes: 66 additions & 16 deletions pkg/tnd/detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@ package tnd
import (
"crypto/sha256"
"encoding/hex"
"errors"
"net"
"reflect"
"testing"
)

// testWatcher is a watcher that implements the routes.Watcher and
// files.Watcher interfaces.
type testWatcher struct{ err error }

func (t *testWatcher) Start() error { return t.err }
func (t *testWatcher) Stop() {}
func (t *testWatcher) Probes() chan struct{} { return nil }

// TestDetectorSetGetServers tests SetServers and GetServers of Detector.
func TestDetectorSetGetServers(t *testing.T) {
tnd := NewDetector(NewConfig())
Expand Down Expand Up @@ -40,16 +49,50 @@ func TestDetectorSetGetDialer(t *testing.T) {
}

// TestTNDStartStop tests Start and Stop of TND.
func TestTNDStartStop(_ *testing.T) {
tnd := NewDetector(NewConfig())
tnd.Start()
tnd.Stop()
func TestTNDStartStop(t *testing.T) {
// test rw error
t.Run("routes watch error", func(t *testing.T) {
tnd := NewDetector(NewConfig())
tnd.rw = &testWatcher{err: errors.New("test error")}
tnd.fw = &testWatcher{}
if err := tnd.Start(); err == nil {
t.Error("start should fail")
return
}
})

// test fw error
t.Run("files watch error", func(t *testing.T) {
tnd := NewDetector(NewConfig())
tnd.rw = &testWatcher{}
tnd.fw = &testWatcher{err: errors.New("test error")}
if err := tnd.Start(); err == nil {
t.Error("start should fail")
return
}
})

// test without errors
t.Run("no errors", func(t *testing.T) {
tnd := NewDetector(NewConfig())
tnd.rw = &testWatcher{}
tnd.fw = &testWatcher{}
if err := tnd.Start(); err != nil {
t.Errorf("start should not fail: %v", err)
return
}
tnd.Stop()
})
}

// TestTNDProbe tests Probe of TND.
func TestTNDProbe(t *testing.T) {
tnd := NewDetector(NewConfig())
tnd.Start()
tnd.rw = &testWatcher{}
tnd.fw = &testWatcher{}
if err := tnd.Start(); err != nil {
t.Fatal(err)
}
tnd.Probe()
want := false
got := <-tnd.Results()
Expand All @@ -71,17 +114,24 @@ func TestTNDResults(t *testing.T) {

// TestNewTND tests NewTND.
func TestNewTND(t *testing.T) {
tnd := NewDetector(NewConfig())
if tnd.probes == nil {
t.Errorf("got nil, want != nil")
}
if tnd.results == nil {
t.Errorf("got nil, want != nil")
}
if tnd.done == nil {
t.Errorf("got nil, want != nil")
c := NewConfig()
tnd := NewDetector(c)

if tnd.config != c {
t.Errorf("got %v, want %v", tnd.config, c)
}
if tnd.dialer == nil {
t.Errorf("got nil, want != nil")

for i, x := range []any{
tnd.probes,
tnd.results,
tnd.done,
tnd.dialer,
tnd.rw,
tnd.fw,
tnd.probeResults,
} {
if x == nil {
t.Errorf("got nil, want != nil: %d", i)
}
}
}
2 changes: 1 addition & 1 deletion pkg/tnd/tnd.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type TND interface {
GetServers() map[string]string
SetDialer(dialer *net.Dialer)
GetDialer() *net.Dialer
Start()
Start() error
Stop()
Probe()
Results() chan bool
Expand Down
7 changes: 4 additions & 3 deletions pkg/tnd/tndtest/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type Funcs struct {
GetServers func() map[string]string
SetDialer func(dialer *net.Dialer)
GetDialer func() *net.Dialer
Start func()
Start func() error
Stop func()
Probe func()
Results func() chan bool
Expand Down Expand Up @@ -54,10 +54,11 @@ func (d *Detector) GetDialer() *net.Dialer {
}

// Start starts the trusted network detection.
func (d *Detector) Start() {
func (d *Detector) Start() error {
if d.Funcs.Start != nil {
d.Funcs.Start()
return d.Funcs.Start()
}
return nil
}

// Stop stops the running TND.
Expand Down
3 changes: 2 additions & 1 deletion pkg/tnd/tndtest/detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ func TestDetectorStart(t *testing.T) {
// test func set
want := true
got := false
d.Funcs.Start = func() {
d.Funcs.Start = func() error {
got = true
return nil
}
d.Start()
if got != want {
Expand Down

0 comments on commit 0033777

Please sign in to comment.