diff --git a/internal/cmdtmpl/command.go b/internal/cmdtmpl/command.go index 6039870..c3d59fa 100644 --- a/internal/cmdtmpl/command.go +++ b/internal/cmdtmpl/command.go @@ -47,39 +47,6 @@ func (cl *CommandList) executeTemplate(tmpl string, data any) (string, error) { return s, nil } -// getCommandListSplitRouting returns the command list identified by name for SplitRouting. -func getCommandListSplitRouting(name string) *CommandList { - var cl *CommandList - switch name { - case "SplitRoutingSetExcludes": - // Set Excludes - cl = &CommandList{ - Name: name, - Commands: []*Command{ - // flush existing entries - // add entries - {Line: "{{.Executables.Nft}} -f -", - Stdin: `flush set inet oc-daemon-routing excludes4 -flush set inet oc-daemon-routing excludes6 -{{range .Addresses -}} -{{if .Addr.Is6 -}} -add element inet oc-daemon-routing excludes6 { {{.}} } -{{else -}} -add element inet oc-daemon-routing excludes4 { {{.}} } -{{end -}} -{{end}}`}, - }, - defaultTemplate: VPNSetupDefaultTemplate, - } - default: - return nil - - } - - cl.template = template.Must(template.New("Template").Parse(cl.defaultTemplate)) - return cl -} - // TrafPolDefaultTemplate is the default template for Traffic Policing. const TrafPolDefaultTemplate = ` {{- define "TrafPolRules"}} @@ -502,6 +469,26 @@ func getCommandListVPNSetup(name string) *CommandList { }, defaultTemplate: VPNSetupDefaultTemplate, } + case "VPNSetupSetExcludes": + // Set Excludes + cl = &CommandList{ + Name: name, + Commands: []*Command{ + // flush existing entries + // add entries + {Line: "{{.Executables.Nft}} -f -", + Stdin: `flush set inet oc-daemon-routing excludes4 +flush set inet oc-daemon-routing excludes6 +{{range .Addresses -}} +{{if .Addr.Is6 -}} +add element inet oc-daemon-routing excludes6 { {{.}} } +{{else -}} +add element inet oc-daemon-routing excludes4 { {{.}} } +{{end -}} +{{end}}`}, + }, + defaultTemplate: VPNSetupDefaultTemplate, + } case "VPNSetupSetupDNSServer": // Setup DNS server cl = &CommandList{ @@ -568,9 +555,6 @@ func getCommandListVPNSetup(name string) *CommandList { // getCommandList returns the command list identified by name. func getCommandList(name string) *CommandList { - if strings.HasPrefix(name, "SplitRouting") { - return getCommandListSplitRouting(name) - } if strings.HasPrefix(name, "TrafPol") { return getCommandListTrafPol(name) } diff --git a/internal/cmdtmpl/command_test.go b/internal/cmdtmpl/command_test.go index 1e11758..85571a7 100644 --- a/internal/cmdtmpl/command_test.go +++ b/internal/cmdtmpl/command_test.go @@ -22,7 +22,6 @@ func TestExecuteTemplateParseError(t *testing.T) { func TestGetCommandList(t *testing.T) { // not existing for _, name := range []string{ - "SplitRoutingDoesNotExist", "TrafPolDoesNotExist", "VPNSetupDoesNotExist", "DoesNotExist", @@ -35,9 +34,6 @@ func TestGetCommandList(t *testing.T) { // existing for _, name := range []string{ - // Split Routing - "SplitRoutingSetExcludes", - // Traffic Policing "TrafPolSetFilterRules", "TrafPolUnsetFilterRules", @@ -52,6 +48,7 @@ func TestGetCommandList(t *testing.T) { // VPN Setup "VPNSetupSetup", "VPNSetupTeardown", + "VPNSetupSetExcludes", "VPNSetupSetupDNSServer", "VPNSetupSetupDNSDomains", "VPNSetupSetupDNSDefaultRoute", @@ -89,9 +86,6 @@ func TestGetCmds(t *testing.T) { // existing, that only need daemon config as input data for _, name := range []string{ - // Split Routing - // "SplitRoutingSetExcludes", // skip, requires excludes - // Traffic Policing "TrafPolSetFilterRules", "TrafPolUnsetFilterRules", @@ -106,6 +100,7 @@ func TestGetCmds(t *testing.T) { // VPN Setup "VPNSetupSetup", "VPNSetupTeardown", + // "VPNSetupSetExcludes", // skip, requires excludes "VPNSetupSetupDNSServer", "VPNSetupSetupDNSDomains", "VPNSetupSetupDNSDefaultRoute", @@ -121,13 +116,13 @@ func TestGetCmds(t *testing.T) { // existing, with insufficient input data for _, name := range []string{ - // Split Routing - "SplitRoutingSetExcludes", - // Traffic Policing "TrafPolAddAllowedDevice", "TrafPolRemoveAllowedDevice", "TrafPolAddAllowedHost", + + // VPN Setup + "VPNSetupSetExcludes", } { if _, err := GetCmds(name, daemoncfg.NewConfig()); err == nil { t.Errorf("insufficient data should return error for list %s", name) diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go index 7f96f57..c4cfb2b 100644 --- a/internal/splitrt/excludes.go +++ b/internal/splitrt/excludes.go @@ -5,7 +5,6 @@ import ( "sync" log "github.com/sirupsen/logrus" - "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) const ( @@ -22,7 +21,6 @@ type dynExclude struct { // Excludes contains split Excludes. type Excludes struct { sync.Mutex - conf *daemoncfg.Config s map[string]netip.Prefix d map[netip.Addr]*dynExclude done chan struct{} @@ -175,9 +173,8 @@ func (e *Excludes) List() (static, dynamic []string) { } // NewExcludes returns new split excludes. -func NewExcludes(conf *daemoncfg.Config) *Excludes { +func NewExcludes() *Excludes { return &Excludes{ - conf: conf, s: make(map[string]netip.Prefix), d: make(map[netip.Addr]*dynExclude), done: make(chan struct{}), diff --git a/internal/splitrt/excludes_test.go b/internal/splitrt/excludes_test.go index 26e311d..537f9dd 100644 --- a/internal/splitrt/excludes_test.go +++ b/internal/splitrt/excludes_test.go @@ -3,8 +3,6 @@ package splitrt import ( "net/netip" "testing" - - "github.com/telekom-mms/oc-daemon/internal/daemoncfg" ) // getTestExcludes returns excludes for testing. @@ -59,7 +57,7 @@ func getTestDynamicExcludes(t *testing.T) []netip.Prefix { // TestExcludesAddStatic tests AddStatic of Excludes. func TestExcludesAddStatic(t *testing.T) { - e := NewExcludes(daemoncfg.NewConfig()) + e := NewExcludes() excludes := getTestStaticExcludes(t) // test adding excludes @@ -77,7 +75,7 @@ func TestExcludesAddStatic(t *testing.T) { } // test adding overlapping excludes - e = NewExcludes(daemoncfg.NewConfig()) + e = NewExcludes() for _, exclude := range getTestStaticExcludesOverlap(t) { e.AddStatic(exclude) } @@ -90,7 +88,7 @@ func TestExcludesAddStatic(t *testing.T) { // TestExcludesAddDynamic tests AddDynamic of Excludes. func TestExcludesAddDynamic(t *testing.T) { - e := NewExcludes(daemoncfg.NewConfig()) + e := NewExcludes() excludes := getTestDynamicExcludes(t) // test adding excludes @@ -110,7 +108,7 @@ func TestExcludesAddDynamic(t *testing.T) { // test adding excludes with existing static excludes, // should only add new excludes statics := getTestStaticExcludes(t) - e = NewExcludes(daemoncfg.NewConfig()) + e = NewExcludes() for _, exclude := range statics { if !e.AddStatic(exclude) { t.Errorf("should add exclude %s", exclude) @@ -132,7 +130,7 @@ func TestExcludesAddDynamic(t *testing.T) { } // test adding invalid excludes (static as dynamic) - e = NewExcludes(daemoncfg.NewConfig()) + e = NewExcludes() for _, exclude := range getTestStaticExcludes(t) { if e.AddDynamic(exclude, 300) { t.Errorf("should not add exclude %s", exclude) @@ -142,7 +140,7 @@ func TestExcludesAddDynamic(t *testing.T) { // TestExcludesRemoveStatic tests RemoveStatic of Excludes. func TestExcludesRemove(t *testing.T) { - e := NewExcludes(daemoncfg.NewConfig()) + e := NewExcludes() excludes := getTestStaticExcludes(t) // test removing not existing excludes @@ -182,7 +180,7 @@ func TestExcludesRemove(t *testing.T) { // TestExcludesCleanup tests cleanup of Excludes. func TestExcludesCleanup(t *testing.T) { - e := NewExcludes(daemoncfg.NewConfig()) + e := NewExcludes() // test without excludes if e.cleanup() { @@ -219,10 +217,8 @@ func TestExcludesCleanup(t *testing.T) { // TestNewExcludes tests NewExcludes. func TestNewExcludes(t *testing.T) { - conf := daemoncfg.NewConfig() - e := NewExcludes(conf) + e := NewExcludes() if e == nil || - e.conf != conf || e.s == nil || e.d == nil || e.done == nil || diff --git a/internal/splitrt/filter.go b/internal/splitrt/filter.go deleted file mode 100644 index 663af24..0000000 --- a/internal/splitrt/filter.go +++ /dev/null @@ -1,38 +0,0 @@ -package splitrt - -import ( - "context" - "net/netip" - - log "github.com/sirupsen/logrus" - "github.com/telekom-mms/oc-daemon/internal/cmdtmpl" - "github.com/telekom-mms/oc-daemon/internal/daemoncfg" -) - -// setExcludes resets the excludes to addresses in netfilter. -func setExcludes(ctx context.Context, conf *daemoncfg.Config, addresses []netip.Prefix) { - data := &struct { - daemoncfg.Config - Addresses []netip.Prefix - }{ - Config: *conf, - Addresses: addresses, - } - cmds, err := cmdtmpl.GetCmds("SplitRoutingSetExcludes", data) - if err != nil { - log.WithError(err).Error("SplitRouting could not get set excludes commands") - } - for _, c := range cmds { - if stdout, stderr, err := c.Run(ctx); err != nil { - log.WithFields(log.Fields{ - "addresses": addresses, - "command": c.Cmd, - "args": c.Args, - "stdin": c.Stdin, - "stdout": string(stdout), - "stderr": string(stderr), - "error": err, - }).Error("SplitRouting could not run set excludes command") - } - } -} diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index 02dcf1f..11466f1 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -2,7 +2,6 @@ package splitrt import ( - "context" "fmt" "net/netip" "sync" @@ -55,7 +54,8 @@ type SplitRouting struct { addrs *Addresses locals locals excludes *Excludes - dnsreps chan *dnsproxy.Report + prefixes chan []netip.Prefix + dnsreps <-chan *dnsproxy.Report done chan struct{} closed chan struct{} } @@ -73,8 +73,16 @@ func (s *SplitRouting) excludeLocalNetworks() (exclude bool, virtual bool) { return } +// sendPrefixes sends the current prefixes over the prefixes channel. +func (s *SplitRouting) sendPrefixes(p []netip.Prefix) { + select { + case s.prefixes <- p: + case <-s.done: + } +} + // updateLocalNetworkExcludes updates the local network split excludes. -func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { +func (s *SplitRouting) updateLocalNetworkExcludes() { exclude, virtual := s.excludeLocalNetworks() // stop if exclude local networks is disabled @@ -106,30 +114,32 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { } // add new excludes + updated := false for _, e := range excludes { if !isIn(e, s.locals.get()) { - if s.excludes.AddStatic(e) { - setExcludes(ctx, s.config, s.excludes.GetPrefixes()) - } + updated = s.excludes.AddStatic(e) || updated } } // remove old excludes for _, l := range s.locals.get() { if !isIn(l, excludes) { - if s.excludes.RemoveStatic(l) { - setExcludes(ctx, s.config, s.excludes.GetPrefixes()) - } + updated = s.excludes.RemoveStatic(l) || updated } } + if updated { + // signal update + s.sendPrefixes(s.excludes.GetPrefixes()) + } + // save local excludes s.locals.set(excludes) log.WithField("locals", s.locals.get()).Debug("SplitRouting updated exclude local networks") } // handleDeviceUpdate handles a device update from the device monitor. -func (s *SplitRouting) handleDeviceUpdate(ctx context.Context, u *devmon.Update) { +func (s *SplitRouting) handleDeviceUpdate(u *devmon.Update) { log.WithField("update", u).Debug("SplitRouting got device update") if u.Add { @@ -146,11 +156,11 @@ func (s *SplitRouting) handleDeviceUpdate(ctx context.Context, u *devmon.Update) } else { s.devices.Remove(u) } - s.updateLocalNetworkExcludes(ctx) + s.updateLocalNetworkExcludes() } // handleAddressUpdate handles an address update from the address monitor. -func (s *SplitRouting) handleAddressUpdate(ctx context.Context, u *addrmon.Update) { +func (s *SplitRouting) handleAddressUpdate(u *addrmon.Update) { log.WithField("update", u).Debug("SplitRouting got address update") if u.Add { @@ -158,36 +168,43 @@ func (s *SplitRouting) handleAddressUpdate(ctx context.Context, u *addrmon.Updat } else { s.addrs.Remove(u) } - s.updateLocalNetworkExcludes(ctx) + s.updateLocalNetworkExcludes() } // handleDNSReport handles a DNS report. -func (s *SplitRouting) handleDNSReport(ctx context.Context, r *dnsproxy.Report) { +func (s *SplitRouting) handleDNSReport(r *dnsproxy.Report) { defer r.Close() log.WithField("report", r).Debug("SplitRouting handling DNS report") exclude := netip.PrefixFrom(r.IP, r.IP.BitLen()) if s.excludes.AddDynamic(exclude, r.TTL) { - setExcludes(ctx, s.config, s.excludes.GetPrefixes()) + // signal update + s.sendPrefixes(s.excludes.GetPrefixes()) } } // start starts split routing. -func (s *SplitRouting) start(ctx context.Context) { +func (s *SplitRouting) start() { defer close(s.closed) + defer close(s.prefixes) defer s.devmon.Stop() defer s.addrmon.Stop() + // send initial prefixes + if prefixes := s.excludes.GetPrefixes(); len(prefixes) > 0 { + s.sendPrefixes(prefixes) + } + // main loop timer := time.NewTimer(excludesTimer * time.Second) for { select { case u := <-s.devmon.Updates(): - s.handleDeviceUpdate(ctx, u) + s.handleDeviceUpdate(u) case u := <-s.addrmon.Updates(): - s.handleAddressUpdate(ctx, u) + s.handleAddressUpdate(u) case r := <-s.dnsreps: - s.handleDNSReport(ctx, r) + s.handleDNSReport(r) case <-timer.C: s.excludes.cleanup() timer.Reset(excludesTimer * time.Second) @@ -204,9 +221,6 @@ func (s *SplitRouting) start(ctx context.Context) { func (s *SplitRouting) Start() error { log.Debug("SplitRouting starting") - // create context - ctx := context.Background() - // start device monitor if err := s.devmon.Start(); err != nil { return fmt.Errorf("SplitRouting could not start DevMon: %w", err) @@ -222,9 +236,7 @@ func (s *SplitRouting) Start() error { if s.config.VPNConfig.Gateway.IsValid() { gateway := netip.PrefixFrom(s.config.VPNConfig.Gateway, s.config.VPNConfig.Gateway.BitLen()) - if s.excludes.AddStatic(gateway) { - setExcludes(ctx, s.config, s.excludes.GetPrefixes()) - } + s.excludes.AddStatic(gateway) } // add static IPv4 excludes @@ -232,9 +244,7 @@ func (s *SplitRouting) Start() error { if e.String() == "0.0.0.0/32" { continue } - if s.excludes.AddStatic(e) { - setExcludes(ctx, s.config, s.excludes.GetPrefixes()) - } + s.excludes.AddStatic(e) } // add static IPv6 excludes @@ -243,12 +253,10 @@ func (s *SplitRouting) Start() error { if e.String() == "::/128" { continue } - if s.excludes.AddStatic(e) { - setExcludes(ctx, s.config, s.excludes.GetPrefixes()) - } + s.excludes.AddStatic(e) } - go s.start(ctx) + go s.start() return nil } @@ -259,9 +267,9 @@ func (s *SplitRouting) Stop() { log.Debug("SplitRouting stopped") } -// DNSReports returns the channel for dns reports. -func (s *SplitRouting) DNSReports() chan *dnsproxy.Report { - return s.dnsreps +// Prefixes returns the channel for the exclude prefixes. +func (s *SplitRouting) Prefixes() <-chan []netip.Prefix { + return s.prefixes } // GetState returns the internal state. @@ -281,15 +289,16 @@ func (s *SplitRouting) GetState() *State { } // NewSplitRouting returns a new SplitRouting. -func NewSplitRouting(config *daemoncfg.Config) *SplitRouting { +func NewSplitRouting(config *daemoncfg.Config, dnsReports <-chan *dnsproxy.Report) *SplitRouting { return &SplitRouting{ config: config, devmon: devmon.NewDevMon(), addrmon: addrmon.NewAddrMon(), devices: NewDevices(), addrs: NewAddresses(), - excludes: NewExcludes(config), - dnsreps: make(chan *dnsproxy.Report), + excludes: NewExcludes(), + prefixes: make(chan []netip.Prefix), + dnsreps: dnsReports, done: make(chan struct{}), closed: make(chan struct{}), } diff --git a/internal/splitrt/splitrt_test.go b/internal/splitrt/splitrt_test.go index f2d33b3..783676d 100644 --- a/internal/splitrt/splitrt_test.go +++ b/internal/splitrt/splitrt_test.go @@ -1,109 +1,103 @@ package splitrt import ( - "context" - "errors" + "cmp" "net/netip" "reflect" + "slices" "testing" "github.com/telekom-mms/oc-daemon/internal/addrmon" "github.com/telekom-mms/oc-daemon/internal/daemoncfg" "github.com/telekom-mms/oc-daemon/internal/devmon" "github.com/telekom-mms/oc-daemon/internal/dnsproxy" - "github.com/telekom-mms/oc-daemon/internal/execs" "github.com/vishvananda/netlink" ) // TestSplitRoutingHandleDeviceUpdate tests handleDeviceUpdate of SplitRouting. func TestSplitRoutingHandleDeviceUpdate(t *testing.T) { - ctx := context.Background() - s := NewSplitRouting(daemoncfg.NewConfig()) - - want := []string{"nothing else"} - got := []string{"nothing else"} - - oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil + s := NewSplitRouting(daemoncfg.NewConfig(), make(chan *dnsproxy.Report)) + + // test helper + want := []netip.Prefix{} + got := []netip.Prefix{} + test := func(t *testing.T, update *devmon.Update) { + done := make(chan struct{}) + go func() { + defer close(done) + s.handleDeviceUpdate(update) + }() + select { + case got = <-s.Prefixes(): + case <-done: + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } } - defer func() { execs.RunCmd = oldRunCmd }() // test adding update := getTestDevMonUpdate() - s.handleDeviceUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + t.Run("adding", func(t *testing.T) { test(t, update) }) // test removing + update = getTestDevMonUpdate() update.Add = false - s.handleDeviceUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + t.Run("removing", func(t *testing.T) { test(t, update) }) // test adding loopback device update = getTestDevMonUpdate() update.Type = "loopback" - s.handleDeviceUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + t.Run("adding loopback", func(t *testing.T) { test(t, update) }) // test adding vpn device update = getTestDevMonUpdate() update.Device = s.config.VPNConfig.Device.Name - s.handleDeviceUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + t.Run("adding vpn device", func(t *testing.T) { test(t, update) }) } // TestSplitRoutingHandleAddressUpdate tests handleAddressUpdate of SplitRouting. func TestSplitRoutingHandleAddressUpdate(t *testing.T) { - ctx := context.Background() // test with exclude conf := daemoncfg.NewConfig() conf.VPNConfig.Split.ExcludeIPv4 = []netip.Prefix{ netip.MustParsePrefix("0.0.0.0/32"), } - s := NewSplitRouting(conf) + s := NewSplitRouting(conf, make(chan *dnsproxy.Report)) s.devices.Add(getTestDevMonUpdate()) - got := []string{} - oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil + // test helper + want := []netip.Prefix{} + got := []netip.Prefix{} + test := func(t *testing.T, update *addrmon.Update) { + done := make(chan struct{}) + go func() { + defer close(done) + s.handleAddressUpdate(update) + }() + select { + case got = <-s.Prefixes(): + case <-done: + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } } - defer func() { execs.RunCmd = oldRunCmd }() // test adding - want := []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", - } update := getTestAddrMonUpdate(t, "192.168.1.1/32") - s.handleAddressUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + want = []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), } + got = []netip.Prefix{} + t.Run("adding with exclude", func(t *testing.T) { test(t, update) }) // test removing - got = []string{} - want = []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n", - } update.Add = false - s.handleAddressUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + want = []netip.Prefix{} + got = []netip.Prefix{} + t.Run("removing with exclude", func(t *testing.T) { test(t, update) }) // test with exclude and virtual conf = daemoncfg.NewConfig() @@ -111,101 +105,84 @@ func TestSplitRoutingHandleAddressUpdate(t *testing.T) { netip.MustParsePrefix("0.0.0.0/32"), } conf.VPNConfig.Split.ExcludeVirtualSubnetsOnlyIPv4 = true - s = NewSplitRouting(conf) + s = NewSplitRouting(conf, make(chan *dnsproxy.Report)) devUp := getTestDevMonUpdate() devUp.Type = "virtual" s.devices.Add(devUp) - got = []string{} - // test adding - want = []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", - } update = getTestAddrMonUpdate(t, "192.168.1.1/32") - s.handleAddressUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + want = []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), } + got = []netip.Prefix{} + t.Run("adding with exclude and virtual", func(t *testing.T) { test(t, update) }) // test double adding - s.handleAddressUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + want = []netip.Prefix{} + got = []netip.Prefix{} + t.Run("double adding with exclude and virtual", func(t *testing.T) { test(t, update) }) // test removing - got = []string{} - want = []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n", - } update.Add = false - s.handleAddressUpdate(ctx, update) - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) - } + want = []netip.Prefix{} + got = []netip.Prefix{} + t.Run("removing with exclude and virtual", func(t *testing.T) { test(t, update) }) } // TestSplitRoutingHandleDNSReport tests handleDNSReport of SplitRouting. func TestSplitRoutingHandleDNSReport(t *testing.T) { - ctx := context.Background() - s := NewSplitRouting(daemoncfg.NewConfig()) - - got := []string{} - oldRunCmd := execs.RunCmd - execs.RunCmd = func(_ context.Context, _ string, s string, _ ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, nil + s := NewSplitRouting(daemoncfg.NewConfig(), make(chan *dnsproxy.Report)) + + // test helper + want := []netip.Prefix{} + got := []netip.Prefix{} + test := func(t *testing.T, report *dnsproxy.Report) { + done := make(chan struct{}) + go func() { + defer close(done) + s.handleDNSReport(report) + <-report.Done() + }() + select { + case got = <-s.Prefixes(): + case <-done: + } + cmpPrefixes := func(a, b netip.Prefix) int { + c := a.Addr().Compare(b.Addr()) + if c == 0 { + return cmp.Compare(a.Bits(), b.Bits()) + } + return c + } + slices.SortFunc(want, cmpPrefixes) + slices.SortFunc(got, cmpPrefixes) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } } - defer func() { execs.RunCmd = oldRunCmd }() // test ipv4 report := dnsproxy.NewReport("example.com", netip.MustParseAddr("192.168.1.1"), 300) - go s.handleDNSReport(ctx, report) - <-report.Done() - - want := []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", - } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %v, want %v", got, want) + want = []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), } + got = []netip.Prefix{} + t.Run("ipv4", func(t *testing.T) { test(t, report) }) // test ipv6 - got = []string{} report = dnsproxy.NewReport("example.com", netip.MustParseAddr("2001::1"), 300) - go s.handleDNSReport(ctx, report) - <-report.Done() - - want = []string{ - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n" + - "add element inet oc-daemon-routing excludes6 { 2001::1/128 }\n", - "flush set inet oc-daemon-routing excludes4\n" + - "flush set inet oc-daemon-routing excludes6\n" + - "add element inet oc-daemon-routing excludes6 { 2001::1/128 }\n" + - "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }\n", - } - if !reflect.DeepEqual(got[0], want[0]) && !reflect.DeepEqual(got[0], want[1]) { - t.Errorf("got %v, want %v or %v", got[0], want[0], want[1]) + want = []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("2001::1/128"), } + got = []netip.Prefix{} + t.Run("ipv6", func(t *testing.T) { test(t, report) }) } // TestSplitRoutingStartStop tests Start and Stop of SplitRouting. func TestSplitRoutingStartStop(t *testing.T) { // set dummy low level functions for testing - oldRunCmd := execs.RunCmd - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { - return nil, nil, nil - } - defer func() { execs.RunCmd = oldRunCmd }() - oldRegisterAddrUpdates := addrmon.RegisterAddrUpdates addrmon.RegisterAddrUpdates = func(*addrmon.AddrMon) (chan netlink.AddrUpdate, error) { return nil, nil @@ -219,7 +196,7 @@ func TestSplitRoutingStartStop(t *testing.T) { defer func() { devmon.RegisterLinkUpdates = oldRegisterLinkUpdates }() // test with new config - s := NewSplitRouting(daemoncfg.NewConfig()) + s := NewSplitRouting(daemoncfg.NewConfig(), make(chan *dnsproxy.Report)) if err := s.Start(); err != nil { t.Error(err) } @@ -227,6 +204,8 @@ func TestSplitRoutingStartStop(t *testing.T) { // test with excludes conf := daemoncfg.NewConfig() + conf.VPNConfig.Gateway = netip.MustParseAddr("10.0.0.1") + conf.VPNConfig.IPv4 = netip.MustParsePrefix("192.168.0.1/24") conf.VPNConfig.Split.ExcludeIPv4 = []netip.Prefix{ netip.MustParsePrefix("0.0.0.0/32"), netip.MustParsePrefix("192.168.1.1/32"), @@ -235,49 +214,77 @@ func TestSplitRoutingStartStop(t *testing.T) { netip.MustParsePrefix("::/128"), netip.MustParsePrefix("2000::1/128"), } - s = NewSplitRouting(conf) - if err := s.Start(); err != nil { - t.Error(err) + s = NewSplitRouting(conf, make(chan *dnsproxy.Report)) + + want := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("2000::1/128"), } - s.Stop() + got := []netip.Prefix{} - // test with vpn address - conf = daemoncfg.NewConfig() - conf.VPNConfig.IPv4 = netip.MustParsePrefix("192.168.1.1/24") - s = NewSplitRouting(daemoncfg.NewConfig()) + done := make(chan struct{}) + go func(prefixes <-chan []netip.Prefix) { + defer close(done) + got = <-prefixes + + }(s.Prefixes()) if err := s.Start(); err != nil { t.Error(err) } + <-done s.Stop() + cmpPrefixes := func(a, b netip.Prefix) int { + c := a.Addr().Compare(b.Addr()) + if c == 0 { + return cmp.Compare(a.Bits(), b.Bits()) + } + return c + } + slices.SortFunc(want, cmpPrefixes) + slices.SortFunc(got, cmpPrefixes) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + // test with events - s = NewSplitRouting(daemoncfg.NewConfig()) + dnsReports := make(chan *dnsproxy.Report) + s = NewSplitRouting(daemoncfg.NewConfig(), dnsReports) + + want = []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + } + got = []netip.Prefix{} + + done = make(chan struct{}) + go func(prefixes <-chan []netip.Prefix) { + defer close(done) + for p := range prefixes { + got = p + } + }(s.Prefixes()) if err := s.Start(); err != nil { t.Error(err) } s.devmon.Updates() <- getTestDevMonUpdate() s.addrmon.Updates() <- getTestAddrMonUpdate(t, "192.168.1.1/32") report := dnsproxy.NewReport("example.com", netip.MustParseAddr("192.168.1.1"), 300) - s.dnsreps <- report + dnsReports <- report <-report.Done() s.Stop() - // test with nft errors - execs.RunCmd = func(context.Context, string, string, ...string) ([]byte, []byte, error) { - return nil, nil, errors.New("test error") - } - s = NewSplitRouting(daemoncfg.NewConfig()) - if err := s.Start(); err != nil { - t.Error(err) + <-done + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) } - s.Stop() } -// TestSplitRoutingDNSReports tests DNSReports of SplitRouting. -func TestSplitRoutingDNSReports(t *testing.T) { - s := NewSplitRouting(daemoncfg.NewConfig()) - want := s.dnsreps - got := s.DNSReports() +// TestSplitRoutingPrefixes tests Prefixes of SplitRouting. +func TestSplitRoutingPrefixes(t *testing.T) { + s := NewSplitRouting(daemoncfg.NewConfig(), make(chan *dnsproxy.Report)) + want := s.prefixes + got := s.Prefixes() if got != want { t.Errorf("got %p, want %p", got, want) } @@ -285,7 +292,7 @@ func TestSplitRoutingDNSReports(t *testing.T) { // TestSplitRoutingGetState tests GetState of SplitRouting. func TestSplitRoutingGetState(t *testing.T) { - s := NewSplitRouting(daemoncfg.NewConfig()) + s := NewSplitRouting(daemoncfg.NewConfig(), make(chan *dnsproxy.Report)) // set devices dev := &devmon.Update{ @@ -333,16 +340,20 @@ func TestSplitRoutingGetState(t *testing.T) { // TestNewSplitRouting tests NewSplitRouting. func TestNewSplitRouting(t *testing.T) { config := daemoncfg.NewConfig() - s := NewSplitRouting(config) + dnsReports := make(chan *dnsproxy.Report) + s := NewSplitRouting(config, dnsReports) if s.config != config { t.Errorf("got %p, want %p", s.config, config) } + if s.dnsreps != dnsReports { + t.Errorf("got %p, want %p", s.dnsreps, dnsReports) + } if s.devmon == nil || s.addrmon == nil || s.devices == nil || s.addrs == nil || s.excludes == nil || - s.dnsreps == nil || + s.prefixes == nil || s.done == nil || s.closed == nil { diff --git a/internal/vpnsetup/vpnsetup.go b/internal/vpnsetup/vpnsetup.go index 10bb5c9..b1d808b 100644 --- a/internal/vpnsetup/vpnsetup.go +++ b/internal/vpnsetup/vpnsetup.go @@ -4,6 +4,7 @@ package vpnsetup import ( "context" "errors" + "net/netip" "slices" "strings" "time" @@ -39,6 +40,7 @@ type command struct { // VPNSetup sets up the configuration of the vpn tunnel that belongs to the // current VPN connection. type VPNSetup struct { + config *daemoncfg.Config splitrt *splitrt.SplitRouting dnsProxy *dnsproxy.Proxy @@ -282,6 +284,9 @@ func (v *VPNSetup) stopEnsure() { // setup sets up the vpn configuration. func (v *VPNSetup) setup(ctx context.Context, conf *daemoncfg.Config) { + // set config + v.config = conf + // configure dns proxy // - set remotes // - set watches @@ -291,6 +296,12 @@ func (v *VPNSetup) setup(ctx context.Context, conf *daemoncfg.Config) { log.WithField("excludes", excludes).Debug("Daemon setting DNS Split Excludes") v.dnsProxy.SetWatches(excludes) + // configure split routing + v.splitrt = splitrt.NewSplitRouting(conf, v.dnsProxy.Reports()) + if err := v.splitrt.Start(); err != nil { + log.WithError(err).Error("VPNSetup error setting split routing") + } + cmds, err := cmdtmpl.GetCmds("VPNSetupSetup", conf) if err != nil { log.WithError(err).Error("VPNSetup could not get setup commands") @@ -308,12 +319,6 @@ func (v *VPNSetup) setup(ctx context.Context, conf *daemoncfg.Config) { } } - // configure split routing - v.splitrt = splitrt.NewSplitRouting(conf) - if err := v.splitrt.Start(); err != nil { - log.WithError(err).Error("VPNSetup error setting split routing") - } - // ensure VPN config v.startEnsure(ctx, conf) } @@ -327,6 +332,7 @@ func (v *VPNSetup) teardown(ctx context.Context, conf *daemoncfg.Config) { v.splitrt.Stop() v.splitrt = nil + // tear down device, routing, dns cmds, err := cmdtmpl.GetCmds("VPNSetupTeardown", conf) if err != nil { log.WithError(err).Error("VPNSetup could not get teardown commands") @@ -351,6 +357,8 @@ func (v *VPNSetup) teardown(ctx context.Context, conf *daemoncfg.Config) { v.dnsProxy.SetRemotes(remotes) v.dnsProxy.SetWatches([]string{}) + // unset config + v.config = nil } // getState gets the internal state. @@ -379,20 +387,31 @@ func (v *VPNSetup) handleCommand(ctx context.Context, c *command) { } } -// handleDNSReport handles a DNS report. -func (v *VPNSetup) handleDNSReport(r *dnsproxy.Report) { - log.WithField("report", r).Debug("Daemon handling DNS report") - - if v.splitrt == nil { - // split routing not active, close report and do not forward - r.Close() - return +// handlePrefixes handles a prefixes update from split routing. +func (v *VPNSetup) handlePrefixes(ctx context.Context, config *daemoncfg.Config, prefixes []netip.Prefix) { + data := &struct { + daemoncfg.Config + Addresses []netip.Prefix + }{ + Config: *config, + Addresses: prefixes, } - - // forward report to split routing - select { - case v.splitrt.DNSReports() <- r: - case <-v.done: + cmds, err := cmdtmpl.GetCmds("VPNSetupSetExcludes", data) + if err != nil { + log.WithError(err).Error("VPNSetup could not get set excludes commands") + } + for _, c := range cmds { + if stdout, stderr, err := c.Run(ctx); err != nil { + log.WithFields(log.Fields{ + "addresses": prefixes, + "command": c.Cmd, + "args": c.Args, + "stdin": c.Stdin, + "stdout": string(stdout), + "stderr": string(stderr), + "error": err, + }).Error("VPNSetup could not run set excludes command") + } } } @@ -408,11 +427,24 @@ func (v *VPNSetup) start() { defer v.dnsProxy.Stop() for { + dnsReports := v.dnsProxy.Reports() + var prefixes <-chan []netip.Prefix + if v.splitrt != nil { + // split routing active + // do not handle dns reports here + dnsReports = nil + // handle prefixes from split routing + prefixes = v.splitrt.Prefixes() + } + select { case c := <-v.cmds: v.handleCommand(ctx, c) - case r := <-v.dnsProxy.Reports(): - v.handleDNSReport(r) + case r := <-dnsReports: + // split routing not active, close dns report + r.Close() + case p := <-prefixes: + v.handlePrefixes(ctx, v.config, p) case <-v.done: return }