diff --git a/internal/routes/watch.go b/internal/routes/watch.go index 6cdf73c..543492f 100644 --- a/internal/routes/watch.go +++ b/internal/routes/watch.go @@ -7,9 +7,17 @@ import ( "golang.org/x/sys/unix" ) +// Watcher is the watcher interface. +type Watcher interface { + Start() error + Stop() + Probes() chan struct{} +} + // Watch waits for routing update events and then probes the // trusted https servers. type Watch struct { + events chan netlink.RouteUpdate probes chan struct{} done chan struct{} } @@ -24,17 +32,11 @@ func (w *Watch) sendProbe() { // start starts the Watch. func (w *Watch) start() { - // register for route update events - events := make(chan netlink.RouteUpdate) - if err := netlink.RouteSubscribe(events, w.done); err != nil { - log.WithError(err).Fatal("TND route subscribe error") - } - // run initial probe w.sendProbe() // handle route update events - for e := range events { + for e := range w.events { switch e.Type { case unix.RTM_NEWROUTE: log.WithField("dst", e.Dst).Debug("TND got route NEW event") @@ -45,9 +47,20 @@ func (w *Watch) start() { } } +// netlinkRouteSubscribe is netlink.RouteSubscribe for testing. +var netlinkRouteSubscribe = netlink.RouteSubscribe + // Start starts the Watch. -func (w *Watch) Start() { +func (w *Watch) Start() error { + // register for route update events + if err := netlinkRouteSubscribe(w.events, w.done); err != nil { + log.WithError(err).Error("TND route subscribe error") + return err + } + + // start watcher go w.start() + return nil } // Stop stops the Watch. @@ -67,6 +80,7 @@ func (w *Watch) Probes() chan struct{} { // NewWatch returns a new Watch. func NewWatch(probes chan struct{}) *Watch { return &Watch{ + events: make(chan netlink.RouteUpdate), probes: probes, done: make(chan struct{}), } diff --git a/internal/routes/watch_test.go b/internal/routes/watch_test.go index 9da3c5d..7125d59 100644 --- a/internal/routes/watch_test.go +++ b/internal/routes/watch_test.go @@ -1,13 +1,35 @@ package routes -import "testing" +import ( + "errors" + "testing" + + "github.com/vishvananda/netlink" +) // TestWatchStartStop tests Start and Stop of Watch. -func TestWatchStartStop(_ *testing.T) { +func TestWatchStartStop(t *testing.T) { probes := make(chan struct{}) - rw := NewWatch(probes) - rw.Start() - rw.Stop() + + t.Run("subscribe error", func(t *testing.T) { + defer func() { netlinkRouteSubscribe = netlink.RouteSubscribe }() + netlinkRouteSubscribe = func(chan<- netlink.RouteUpdate, <-chan struct{}) error { + return errors.New("test error") + } + + rw := NewWatch(probes) + if err := rw.Start(); err == nil { + t.Error("start should fail") + } + }) + + t.Run("no errors", func(t *testing.T) { + rw := NewWatch(probes) + if err := rw.Start(); err != nil { + t.Errorf("start should not fail: %v", err) + } + rw.Stop() + }) } // TestWatchProbes tests Probes of Watch. @@ -23,6 +45,9 @@ func TestWatchProbes(t *testing.T) { func TestNewWatch(t *testing.T) { probes := make(chan struct{}) rw := NewWatch(probes) + if rw.events == nil { + t.Errorf("got nil, want != nil") + } if rw.probes != probes { t.Errorf("got %p, want %p", rw.probes, probes) }