From 0033777ffd3be779592757a11c8c9358803fc0b4 Mon Sep 17 00:00:00 2001 From: hwipl <33433250+hwipl@users.noreply.github.com> Date: Mon, 4 Dec 2023 15:32:58 +0100 Subject: [PATCH] Make start of TND fallible Signed-off-by: hwipl <33433250+hwipl@users.noreply.github.com> --- pkg/tnd/detector.go | 36 +++++++++----- pkg/tnd/detector_test.go | 82 +++++++++++++++++++++++++------- pkg/tnd/tnd.go | 2 +- pkg/tnd/tndtest/detector.go | 7 +-- pkg/tnd/tndtest/detector_test.go | 3 +- 5 files changed, 97 insertions(+), 33 deletions(-) diff --git a/pkg/tnd/detector.go b/pkg/tnd/detector.go index 6b61f4c..40d2a76 100644 --- a/pkg/tnd/detector.go +++ b/pkg/tnd/detector.go @@ -19,6 +19,10 @@ type Detector struct { servers []*https.Server dialer *net.Dialer + // route and file watch + rw routes.Watcher + fw files.Watcher + // timer timer *time.Timer @@ -106,16 +110,8 @@ func (d *Detector) resetTimer() { func (d *Detector) start() { // signal stop to user via results defer close(d.results) - - // start route watching - rw := routes.NewWatch(d.probes) - rw.Start() - defer rw.Stop() - - // start file watching - fw := files.NewWatch(d.probes) - fw.Start() - defer fw.Stop() + defer d.rw.Stop() + defer d.fw.Stop() // set timer for periodic checks d.timer = time.NewTimer(d.config.UntrustedTimer) @@ -177,8 +173,21 @@ func (d *Detector) start() { } // Start starts the trusted network detection. -func (d *Detector) Start() { +func (d *Detector) Start() error { + // start route watching + if err := d.rw.Start(); err != nil { + return err + } + + // start file watching + if err := d.fw.Start(); err != nil { + d.rw.Stop() + return err + } + + // start detector go d.start() + return nil } // Stop stops the running TND. @@ -205,12 +214,15 @@ func (d *Detector) Results() chan bool { // NewDetector returns a new Detector. func NewDetector(config *Config) *Detector { + probes := make(chan struct{}) return &Detector{ config: config, - probes: make(chan struct{}), + probes: probes, results: make(chan bool), done: make(chan struct{}), dialer: &net.Dialer{}, + rw: routes.NewWatch(probes), + fw: files.NewWatch(probes), probeResults: make(chan bool), } diff --git a/pkg/tnd/detector_test.go b/pkg/tnd/detector_test.go index 5cb9296..b992b92 100644 --- a/pkg/tnd/detector_test.go +++ b/pkg/tnd/detector_test.go @@ -3,11 +3,20 @@ package tnd import ( "crypto/sha256" "encoding/hex" + "errors" "net" "reflect" "testing" ) +// testWatcher is a watcher that implements the routes.Watcher and +// files.Watcher interfaces. +type testWatcher struct{ err error } + +func (t *testWatcher) Start() error { return t.err } +func (t *testWatcher) Stop() {} +func (t *testWatcher) Probes() chan struct{} { return nil } + // TestDetectorSetGetServers tests SetServers and GetServers of Detector. func TestDetectorSetGetServers(t *testing.T) { tnd := NewDetector(NewConfig()) @@ -40,16 +49,50 @@ func TestDetectorSetGetDialer(t *testing.T) { } // TestTNDStartStop tests Start and Stop of TND. -func TestTNDStartStop(_ *testing.T) { - tnd := NewDetector(NewConfig()) - tnd.Start() - tnd.Stop() +func TestTNDStartStop(t *testing.T) { + // test rw error + t.Run("routes watch error", func(t *testing.T) { + tnd := NewDetector(NewConfig()) + tnd.rw = &testWatcher{err: errors.New("test error")} + tnd.fw = &testWatcher{} + if err := tnd.Start(); err == nil { + t.Error("start should fail") + return + } + }) + + // test fw error + t.Run("files watch error", func(t *testing.T) { + tnd := NewDetector(NewConfig()) + tnd.rw = &testWatcher{} + tnd.fw = &testWatcher{err: errors.New("test error")} + if err := tnd.Start(); err == nil { + t.Error("start should fail") + return + } + }) + + // test without errors + t.Run("no errors", func(t *testing.T) { + tnd := NewDetector(NewConfig()) + tnd.rw = &testWatcher{} + tnd.fw = &testWatcher{} + if err := tnd.Start(); err != nil { + t.Errorf("start should not fail: %v", err) + return + } + tnd.Stop() + }) } // TestTNDProbe tests Probe of TND. func TestTNDProbe(t *testing.T) { tnd := NewDetector(NewConfig()) - tnd.Start() + tnd.rw = &testWatcher{} + tnd.fw = &testWatcher{} + if err := tnd.Start(); err != nil { + t.Fatal(err) + } tnd.Probe() want := false got := <-tnd.Results() @@ -71,17 +114,24 @@ func TestTNDResults(t *testing.T) { // TestNewTND tests NewTND. func TestNewTND(t *testing.T) { - tnd := NewDetector(NewConfig()) - if tnd.probes == nil { - t.Errorf("got nil, want != nil") - } - if tnd.results == nil { - t.Errorf("got nil, want != nil") - } - if tnd.done == nil { - t.Errorf("got nil, want != nil") + c := NewConfig() + tnd := NewDetector(c) + + if tnd.config != c { + t.Errorf("got %v, want %v", tnd.config, c) } - if tnd.dialer == nil { - t.Errorf("got nil, want != nil") + + for i, x := range []any{ + tnd.probes, + tnd.results, + tnd.done, + tnd.dialer, + tnd.rw, + tnd.fw, + tnd.probeResults, + } { + if x == nil { + t.Errorf("got nil, want != nil: %d", i) + } } } diff --git a/pkg/tnd/tnd.go b/pkg/tnd/tnd.go index c556818..44f5665 100644 --- a/pkg/tnd/tnd.go +++ b/pkg/tnd/tnd.go @@ -11,7 +11,7 @@ type TND interface { GetServers() map[string]string SetDialer(dialer *net.Dialer) GetDialer() *net.Dialer - Start() + Start() error Stop() Probe() Results() chan bool diff --git a/pkg/tnd/tndtest/detector.go b/pkg/tnd/tndtest/detector.go index a40bef2..37ca46d 100644 --- a/pkg/tnd/tndtest/detector.go +++ b/pkg/tnd/tndtest/detector.go @@ -11,7 +11,7 @@ type Funcs struct { GetServers func() map[string]string SetDialer func(dialer *net.Dialer) GetDialer func() *net.Dialer - Start func() + Start func() error Stop func() Probe func() Results func() chan bool @@ -54,10 +54,11 @@ func (d *Detector) GetDialer() *net.Dialer { } // Start starts the trusted network detection. -func (d *Detector) Start() { +func (d *Detector) Start() error { if d.Funcs.Start != nil { - d.Funcs.Start() + return d.Funcs.Start() } + return nil } // Stop stops the running TND. diff --git a/pkg/tnd/tndtest/detector_test.go b/pkg/tnd/tndtest/detector_test.go index acefcd5..049cbf5 100644 --- a/pkg/tnd/tndtest/detector_test.go +++ b/pkg/tnd/tndtest/detector_test.go @@ -71,8 +71,9 @@ func TestDetectorStart(t *testing.T) { // test func set want := true got := false - d.Funcs.Start = func() { + d.Funcs.Start = func() error { got = true + return nil } d.Start() if got != want {