diff --git a/internal/splitrt/excludes.go b/internal/splitrt/excludes.go index 5046bce..be03487 100644 --- a/internal/splitrt/excludes.go +++ b/internal/splitrt/excludes.go @@ -3,6 +3,7 @@ package splitrt import ( "context" "net" + "net/netip" "sync" "time" @@ -14,10 +15,8 @@ const ( excludesTimer = 300 ) -// exclude is a split excludes entry. -type exclude struct { - net *net.IPNet - static bool +// dynExclude is a dynamic split excludes entry. +type dynExclude struct { ttl uint32 updated bool } @@ -25,61 +24,71 @@ type exclude struct { // Excludes contains split Excludes. type Excludes struct { sync.Mutex - m map[string]*exclude + s map[string]*netip.Prefix + d map[netip.Addr]*dynExclude done chan struct{} closed chan struct{} } -// addFilter adds the exclude to netfilter. -func (e *Excludes) addFilter(ctx context.Context, exclude *exclude) { - log.WithField("address", exclude.net).Debug("SplitRouting adding exclude to netfilter") - addExclude(ctx, exclude.net) -} - // setFilter resets the excludes in netfilter. func (e *Excludes) setFilter(ctx context.Context) { log.Debug("SplitRouting resetting excludes in netfilter") - addresses := []*net.IPNet{} - for _, v := range e.m { - addresses = append(addresses, v.net) + addresses := []*netip.Prefix{} + for _, v := range e.s { + addresses = append(addresses, v) + } + for k := range e.d { + prefix := netip.PrefixFrom(k, k.BitLen()) + addresses = append(addresses, &prefix) } setExcludes(ctx, addresses) } -// add adds the exclude entry for ip to the split excludes. -func (e *Excludes) add(ctx context.Context, ip *net.IPNet, exclude *exclude) { +// AddStatic adds a static entry to the split excludes. +func (e *Excludes) AddStatic(ctx context.Context, address *net.IPNet) { + log.WithField("address", address).Debug("SplitRouting adding static exclude") + + a, err := netip.ParsePrefix(address.String()) + if err != nil { + log.WithError(err).Error("SplitRouting could not parse static exclude") + return + } + e.Lock() defer e.Unlock() - key := ip.String() - old := e.m[key] + // make sure new prefix in address does not overlap with existing + // prefixes in static excludes + removed := false + for k, v := range e.s { + if !v.Overlaps(a) { + // no overlap + continue + } + if v.Bits() <= a.Bits() { + // new prefix is already in existing prefix, + // do not add it + return + } - // new entry, just add it - if old == nil { - e.m[key] = exclude - e.addFilter(ctx, exclude) - return + // new prefix contains old prefix, remove old prefix + delete(e.s, k) + removed = true } - // old entry exists, update values - if old.static { - // static entry is not updated + // add new prefix to static excludes + key := address.String() + e.s[key] = &a + + // add to netfilter + if removed { + // existing entries removed, we need to reset all excludes + e.setFilter(ctx) return } - // update entry - old.static = exclude.static - old.ttl = exclude.ttl - old.updated = true -} - -// AddStatic adds a static entry to the split excludes. -func (e *Excludes) AddStatic(ctx context.Context, address *net.IPNet) { - log.WithField("address", address).Debug("SplitRouting adding static exclude") - e.add(ctx, address, &exclude{ - net: address, - static: true, - }) + // single new entry, add it + addExclude(ctx, &a) } // AddDynamic adds a dynamic entry to the split excludes. @@ -88,34 +97,62 @@ func (e *Excludes) AddDynamic(ctx context.Context, address *net.IPNet, ttl uint3 "address": address, "ttl": ttl, }).Debug("SplitRouting adding dynamic exclude") - e.add(ctx, address, &exclude{ - net: address, + + prefix, err := netip.ParsePrefix(address.String()) + if err != nil { + log.WithError(err).Error("SplitRouting could not parse dynamic exclude") + return + } + if !prefix.IsSingleIP() { + log.WithError(err).Error("SplitRouting error adding dynamic exclude with multiple IPs") + return + } + a := prefix.Addr() + + e.Lock() + defer e.Unlock() + + // make sure new ip address is not in existing static excludes + for _, v := range e.s { + if v.Contains(a) { + return + } + } + + // update existing entry in dynamic excludes + old := e.d[a] + if old != nil { + old.ttl = ttl + old.updated = true + return + } + + // create new entry in dynamic excludes + e.d[a] = &dynExclude{ ttl: ttl, updated: true, - }) + } + + // add to netfilter + addExclude(ctx, &prefix) } -// Remove removes an entry from the split excludes. -func (e *Excludes) Remove(ctx context.Context, address *net.IPNet) { +// RemoveStatic removes a static entry from the split excludes. +func (e *Excludes) RemoveStatic(ctx context.Context, address *net.IPNet) { e.Lock() defer e.Unlock() - delete(e.m, address.String()) + delete(e.s, address.String()) e.setFilter(ctx) } -// cleanup cleans up the split excludes. +// cleanup cleans up the dynamic split excludes. func (e *Excludes) cleanup(ctx context.Context) { e.Lock() defer e.Unlock() changed := false - for k, v := range e.m { - // skip static entries - if v.static { - continue - } - + for k, v := range e.d { // skip recently updated entries if v.updated { v.updated = false @@ -124,7 +161,7 @@ func (e *Excludes) cleanup(ctx context.Context) { // exclude expired entries if v.ttl < excludesTimer { - delete(e.m, k) + delete(e.d, k) changed = true continue } @@ -176,7 +213,8 @@ func (e *Excludes) Stop() { // NewExcludes returns new split excludes. func NewExcludes() *Excludes { return &Excludes{ - m: make(map[string]*exclude), + s: make(map[string]*netip.Prefix), + d: make(map[netip.Addr]*dynExclude), done: make(chan struct{}), closed: make(chan struct{}), } diff --git a/internal/splitrt/excludes_test.go b/internal/splitrt/excludes_test.go index b290037..872aa8b 100644 --- a/internal/splitrt/excludes_test.go +++ b/internal/splitrt/excludes_test.go @@ -11,12 +11,9 @@ import ( ) // getTestExcludes returns excludes for testing. -func getTestExcludes(t *testing.T) []*net.IPNet { +func getTestExcludes(t *testing.T, es []string) []*net.IPNet { excludes := []*net.IPNet{} - for _, s := range []string{ - "192.168.1.0/24", - "2001::/64", - } { + for _, s := range es { _, exclude, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) @@ -26,11 +23,48 @@ func getTestExcludes(t *testing.T) []*net.IPNet { return excludes } +// getTestStaticExcludes returns static excludes for testing. +func getTestStaticExcludes(t *testing.T) []*net.IPNet { + return getTestExcludes(t, []string{ + "192.168.1.0/24", + "2001::/64", + }) +} + +// getTestStaticExcludesOverlap returns static excludes that overlap for testing. +func getTestStaticExcludesOverlap(t *testing.T) []*net.IPNet { + return getTestExcludes(t, []string{ + "192.168.1.0/26", + "192.168.1.64/26", + "192.168.1.128/26", + "192.168.1.192/26", + "192.168.1.0/25", + "192.168.1.128/25", + "192.168.1.0/24", + "2001:2001:2001:2000::/64", + "2001:2001:2001:2001::/64", + "2001:2001:2001:2002::/64", + "2001:2001:2001:2003::/64", + "2001:2001:2001:2000::/63", + "2001:2001:2001:2002::/63", + "2001:2001:2001:2000::/56", + }) +} + +// getTestDynamicExcludes returns dynamic excludes for testing. +func getTestDynamicExcludes(t *testing.T) []*net.IPNet { + return getTestExcludes(t, []string{ + "192.168.1.1/32", + "2001::1/128", + "172.16.1.1/32", + }) +} + // TestExcludesAddStatic tests AddStatic of Excludes. func TestExcludesAddStatic(t *testing.T) { ctx := context.Background() e := NewExcludes() - excludes := getTestExcludes(t) + excludes := getTestStaticExcludes(t) // set testing runNft function got := []string{} @@ -58,13 +92,24 @@ func TestExcludesAddStatic(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } + + // test adding overlapping excludes + e = NewExcludes() + for _, exclude := range getTestStaticExcludesOverlap(t) { + e.AddStatic(ctx, exclude) + } + for k := range e.s { + if k != "192.168.1.0/24" && k != "2001:2001:2001:2000::/56" { + t.Errorf("unexpected key: %s", k) + } + } } // TestExcludesAddDynamic tests AddDynamic of Excludes. func TestExcludesAddDynamic(t *testing.T) { ctx := context.Background() e := NewExcludes() - excludes := getTestExcludes(t) + excludes := getTestDynamicExcludes(t) // set testing runNft function got := []string{} @@ -75,8 +120,9 @@ func TestExcludesAddDynamic(t *testing.T) { // test adding excludes want := []string{ - "add element inet oc-daemon-routing excludes4 { 192.168.1.0/24 }", - "add element inet oc-daemon-routing excludes6 { 2001::/64 }", + "add element inet oc-daemon-routing excludes4 { 192.168.1.1/32 }", + "add element inet oc-daemon-routing excludes6 { 2001::1/128 }", + "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }", } for _, exclude := range excludes { e.AddDynamic(ctx, exclude, 300) @@ -92,13 +138,41 @@ func TestExcludesAddDynamic(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } + + // test adding excludes with existing static excludes, + // should only add new excludes + e = NewExcludes() + for _, exclude := range getTestStaticExcludes(t) { + e.AddStatic(ctx, exclude) + } + got = []string{} + want = []string{ + "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }", + } + for _, exclude := range excludes { + e.AddDynamic(ctx, exclude, 300) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + + // test adding invalid excludes (static as dynamic) + e = NewExcludes() + got = []string{} + want = []string{} + for _, exclude := range getTestStaticExcludes(t) { + e.AddDynamic(ctx, exclude, 300) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } } -// TestExcludesRemove tests Remove of Excludes. +// TestExcludesRemoveStatic tests RemoveStatic of Excludes. func TestExcludesRemove(t *testing.T) { ctx := context.Background() e := NewExcludes() - excludes := getTestExcludes(t) + excludes := getTestStaticExcludes(t) // set testing runNft function got := []string{} @@ -117,7 +191,7 @@ func TestExcludesRemove(t *testing.T) { "flush set inet oc-daemon-routing excludes6\n", } for _, exclude := range excludes { - e.Remove(ctx, exclude) + e.RemoveStatic(ctx, exclude) } if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -138,36 +212,50 @@ func TestExcludesRemove(t *testing.T) { e.AddStatic(ctx, exclude) } for _, exclude := range excludes { - e.Remove(ctx, exclude) + e.RemoveStatic(ctx, exclude) } if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } - // test removing dynamic excludes - // should have same nft commands as static case + // test with nft error got = []string{} + execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { + got = append(got, s) + return nil, nil, errors.New("test error") + } for _, exclude := range excludes { - e.AddDynamic(ctx, exclude, 300) + e.AddStatic(ctx, exclude) } for _, exclude := range excludes { - e.Remove(ctx, exclude) + e.RemoveStatic(ctx, exclude) } if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) } - // test with nft error + // test removing with dynamic excludes got = []string{} - execs.RunCmd = func(ctx context.Context, cmd string, s string, arg ...string) ([]byte, []byte, error) { - got = append(got, s) - return nil, nil, errors.New("test error") + want = []string{ + "add element inet oc-daemon-routing excludes4 { 192.168.1.0/24 }", + "add element inet oc-daemon-routing excludes6 { 2001::/64 }", + "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }", + "flush set inet oc-daemon-routing excludes4\n" + + "flush set inet oc-daemon-routing excludes6\n" + + "add element inet oc-daemon-routing excludes6 { 2001::/64 }\n" + + "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }\n", + "flush set inet oc-daemon-routing excludes4\n" + + "flush set inet oc-daemon-routing excludes6\n" + + "add element inet oc-daemon-routing excludes4 { 172.16.1.1/32 }\n", } for _, exclude := range excludes { e.AddStatic(ctx, exclude) } + for _, exclude := range getTestDynamicExcludes(t) { + e.AddDynamic(ctx, exclude, 300) + } for _, exclude := range excludes { - e.Remove(ctx, exclude) + e.RemoveStatic(ctx, exclude) } if !reflect.DeepEqual(got, want) { t.Errorf("got %v, want %v", got, want) @@ -178,7 +266,6 @@ func TestExcludesRemove(t *testing.T) { func TestExcludesCleanup(t *testing.T) { ctx := context.Background() e := NewExcludes() - excludes := getTestExcludes(t) // set testing runNft function got := []string{} @@ -195,7 +282,7 @@ func TestExcludesCleanup(t *testing.T) { } // test with dynamic excludes - for _, exclude := range excludes { + for _, exclude := range getTestDynamicExcludes(t) { e.AddDynamic(ctx, exclude, excludesTimer) } @@ -217,7 +304,7 @@ func TestExcludesCleanup(t *testing.T) { } // test with static excludes - for _, exclude := range excludes { + for _, exclude := range getTestStaticExcludes(t) { e.AddStatic(ctx, exclude) } got = []string{} @@ -238,7 +325,9 @@ func TestExcludesStartStop(_ *testing.T) { // TestNewExcludes tests NewExcludes. func TestNewExcludes(t *testing.T) { e := NewExcludes() - if e.m == nil || + if e == nil || + e.s == nil || + e.d == nil || e.done == nil || e.closed == nil { diff --git a/internal/splitrt/filter.go b/internal/splitrt/filter.go index afcdb21..6e59411 100644 --- a/internal/splitrt/filter.go +++ b/internal/splitrt/filter.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "strings" log "github.com/sirupsen/logrus" @@ -190,9 +191,11 @@ func rejectIPv4(ctx context.Context, device string) { } // addExclude adds exclude address to netfilter. -func addExclude(ctx context.Context, address *net.IPNet) { +func addExclude(ctx context.Context, address *netip.Prefix) { + log.WithField("address", address).Debug("SplitRouting adding exclude to netfilter") + set := "excludes4" - if address.IP.To4() == nil { + if address.Addr().Is6() { set = "excludes6" } @@ -207,7 +210,7 @@ func addExclude(ctx context.Context, address *net.IPNet) { } // setExcludes resets the excludes to addresses in netfilter. -func setExcludes(ctx context.Context, addresses []*net.IPNet) { +func setExcludes(ctx context.Context, addresses []*netip.Prefix) { // flush existing entries nftconf := "" nftconf += "flush set inet oc-daemon-routing excludes4\n" @@ -216,7 +219,7 @@ func setExcludes(ctx context.Context, addresses []*net.IPNet) { // add entries for _, a := range addresses { set := "excludes4" - if a.IP.To4() == nil { + if a.Addr().Is6() { set = "excludes6" } nftconf += fmt.Sprintf( diff --git a/internal/splitrt/splitrt.go b/internal/splitrt/splitrt.go index bb2c94b..081454c 100644 --- a/internal/splitrt/splitrt.go +++ b/internal/splitrt/splitrt.go @@ -158,7 +158,7 @@ func (s *SplitRouting) updateLocalNetworkExcludes(ctx context.Context) { // remove old excludes for _, l := range s.locals { if !isIn(l, excludes) { - s.excludes.Remove(ctx, l) + s.excludes.RemoveStatic(ctx, l) } }