From 67901a996fa55076898da2a0bb3667075f93d253 Mon Sep 17 00:00:00 2001 From: Elliot Cubit Date: Fri, 26 May 2023 14:20:46 -0400 Subject: [PATCH 1/5] Format with golines --- cidranger.go | 9 +++++++-- cidranger_test.go | 33 ++++++++++++++++++++++++++++----- trie_test.go | 20 +++++++++++++++++--- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/cidranger.go b/cidranger.go index 66f6685..eb95b80 100644 --- a/cidranger.go +++ b/cidranger.go @@ -271,8 +271,13 @@ func getRollupEntries(trie *prefixTrie, f RollupApply) []RangerEntry { // If both have an entry, check to rollup if node.children[0].hasEntry() && node.children[1].hasEntry() { if f.CanRollup(node.children[0].entry, node.children[1].entry) { - rollupEntries = append(rollupEntries, - f.GetParentEntry(node.children[0].entry, node.children[1].entry, node.network.IPNet), + rollupEntries = append( + rollupEntries, + f.GetParentEntry( + node.children[0].entry, + node.children[1].entry, + node.network.IPNet, + ), ) } continue diff --git a/cidranger_test.go b/cidranger_test.go index 03389a5..1af558e 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -129,7 +129,13 @@ func TestSubnets(t *testing.T) { name string }{ {"0.0.0.0/8", 33, nil, rnet.ErrBadMaskLength, "IPv4 prefix too long"}, - {"0.0.0.0/0", 2, []string{"0.0.0.0/2", "64.0.0.0/2", "128.0.0.0/2", "192.0.0.0/2"}, nil, "IPv4 /0 to /2"}, + { + "0.0.0.0/0", + 2, + []string{"0.0.0.0/2", "64.0.0.0/2", "128.0.0.0/2", "192.0.0.0/2"}, + nil, + "IPv4 /0 to /2", + }, {"10.0.0.0/8", 0, []string{"10.0.0.0/9", "10.128.0.0/9"}, nil, "IPv4 default split /8"}, {"::/2", 4, []string{"::/4", "1000::/4", "2000::/4", "3000::/4"}, nil, "IPv6 /2 to /4"}, {"10.0.0.0/15", 15, []string{"10.0.0.0/15"}, nil, "IPv4 prefix self"}, @@ -165,7 +171,12 @@ func TestBredthRangerIter(t *testing.T) { }{ {rnet.IPv4, []string{}, []string{}, "empty"}, {rnet.IPv4, []string{"1.2.3.4/15"}, []string{"1.2.3.4/15"}, "single v4"}, - {rnet.IPv4, []string{"255.0.0.0/8", "0.0.0.0/8"}, []string{"0.0.0.0/8", "255.0.0.0/8"}, "ordering v4"}, + { + rnet.IPv4, + []string{"255.0.0.0/8", "0.0.0.0/8"}, + []string{"0.0.0.0/8", "255.0.0.0/8"}, + "ordering v4", + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { @@ -218,7 +229,11 @@ func (rc rollupCount) CanRollup(c0 RangerEntry, c1 RangerEntry) bool { return rc0.Records+rc1.Records < rc.rollupThreshold } -func (rc rollupCount) GetParentEntry(c0 RangerEntry, c1 RangerEntry, parentNet net.IPNet) RangerEntry { +func (rc rollupCount) GetParentEntry( + c0 RangerEntry, + c1 RangerEntry, + parentNet net.IPNet, +) RangerEntry { rc0, ok := c0.(RecordEntry) if !ok { panic(c0) @@ -332,10 +347,18 @@ func BenchmarkBruteRangerHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { } func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) + benchmarkContainingNetworksUsingAWSRanges( + b, + net.ParseIP("2620:107:300f::36b7:ff81"), + NewPCTrieRanger(), + ) } func BenchmarkBruteRangerHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) + benchmarkContainingNetworksUsingAWSRanges( + b, + net.ParseIP("2620:107:300f::36b7:ff81"), + newBruteRanger(), + ) } func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { diff --git a/trie_test.go b/trie_test.go index da7b777..2ac9a9a 100644 --- a/trie_test.go +++ b/trie_test.go @@ -79,7 +79,12 @@ func TestPrefixTrieInsert(t *testing.T) { assert.NoError(t, err) } - assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match") + assert.Equal( + t, + len(tc.expectedNetworksInDepthOrder), + trie.Len(), + "trie size should match", + ) allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version)) assert.Nil(t, err) @@ -213,7 +218,12 @@ func TestPrefixTrieRemove(t *testing.T) { } } - assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval") + assert.Equal( + t, + len(tc.expectedNetworksInDepthOrder), + trie.Len(), + "trie size should match after revmoval", + ) allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version)) assert.Nil(t, err) @@ -686,7 +696,11 @@ func TestTrieMemUsage(t *testing.T) { // Assert that heap allocation from first loop is within set threshold of avg over all runs. assert.Less(t, uint64(0), baseLineHeap) - assert.LessOrEqual(t, float64(baseLineHeap), float64(totalHeapAllocOverRuns/uint64(runs))*thresh) + assert.LessOrEqual( + t, + float64(baseLineHeap), + float64(totalHeapAllocOverRuns/uint64(runs))*thresh, + ) } func GenLeafIPNet(ip net.IP) net.IPNet { From bf85f45cb5900d0a8612ae4eca1bde503abc49ec Mon Sep 17 00:00:00 2001 From: Elliot Cubit Date: Fri, 26 May 2023 14:20:46 -0400 Subject: [PATCH 2/5] Add GetPath method to RangerIter --- cidranger.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/cidranger.go b/cidranger.go index eb95b80..5516b96 100644 --- a/cidranger.go +++ b/cidranger.go @@ -150,6 +150,7 @@ func Subnets(base net.IPNet, prefixlen int) (subnets []net.IPNet, err error) { type RangerIter interface { Next() bool Get() RangerEntry + GetPath() []RangerEntry Error() error } @@ -206,6 +207,16 @@ func (i *bredthRangerIter) Get() RangerEntry { return i.node.entry } +func (i *bredthRangerIter) GetPath() []RangerEntry { + retv := make([]RangerEntry, 0) + for this := i.node; this.parent != nil; this = this.parent { + if this.hasEntry() { + retv = append(retv, this.entry) + } + } + return retv +} + func (i *bredthRangerIter) Error() error { return nil } From 313d871f0bd9c65045fa323931123e8187596355 Mon Sep 17 00:00:00 2001 From: Elliot Cubit Date: Fri, 26 May 2023 14:20:46 -0400 Subject: [PATCH 3/5] Add TraversalMethod for RangerIter --- cidranger.go | 42 ++++++++++++++++++++++++----------- cidranger_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/cidranger.go b/cidranger.go index 5516b96..2ea5393 100644 --- a/cidranger.go +++ b/cidranger.go @@ -59,6 +59,13 @@ var AllIPv4 = parseCIDRUnsafe("0.0.0.0/0") // AllIPv6 is a IPv6 CIDR that contains all networks var AllIPv6 = parseCIDRUnsafe("0::0/0") +type TraversalMethod int + +const ( + TraversalMethodBreadth = iota + TraversalMethodDepth = iota +) + func parseCIDRUnsafe(s string) *net.IPNet { _, cidr, _ := net.ParseCIDR(s) return cidr @@ -154,36 +161,38 @@ type RangerIter interface { Error() error } -type bredthRangerIter struct { +type rangerIter struct { path *list.List node *prefixTrie shallow bool + method TraversalMethod } // A bredth-first iterator that returns all netblocks with a RangerEntry -func NewBredthIter(r Ranger) bredthRangerIter { - return newBredthIter(r, false) +func NewBredthIter(r Ranger) *rangerIter { + return NewIter(r, false, TraversalMethodBreadth) } // A bredth-first iterator that will return only the largest netblocks with an entry -func NewShallowBredthIter(r Ranger) bredthRangerIter { - return newBredthIter(r, true) +func NewShallowBredthIter(r Ranger) *rangerIter { + return NewIter(r, true, TraversalMethodBreadth) } -func newBredthIter(r Ranger, shallow bool) bredthRangerIter { +func NewIter(r Ranger, shallow bool, method TraversalMethod) *rangerIter { root, ok := r.(*prefixTrie) if !ok { panic(fmt.Errorf("Invalid type for bredthRangerIter")) } - iter := bredthRangerIter{ + iter := rangerIter{ node: root, path: list.New(), shallow: shallow, + method: method, } iter.path.PushBack(root) - return iter + return &iter } -func (i *bredthRangerIter) Next() bool { +func (i *rangerIter) Next() bool { for i.path.Len() > 0 { element := i.path.Front() i.path.Remove(element) @@ -193,7 +202,14 @@ func (i *bredthRangerIter) Next() bool { } for _, child := range i.node.children { if child != nil { - i.path.PushBack(child) + switch i.method { + case TraversalMethodBreadth: + i.path.PushBack(child) + case TraversalMethodDepth: + i.path.PushFront(child) + default: + panic(fmt.Sprintf("Unrecognized TraversalMethod %v", i.method)) + } } } if i.node.hasEntry() { @@ -203,11 +219,11 @@ func (i *bredthRangerIter) Next() bool { return false } -func (i *bredthRangerIter) Get() RangerEntry { +func (i *rangerIter) Get() RangerEntry { return i.node.entry } -func (i *bredthRangerIter) GetPath() []RangerEntry { +func (i *rangerIter) GetPath() []RangerEntry { retv := make([]RangerEntry, 0) for this := i.node; this.parent != nil; this = this.parent { if this.hasEntry() { @@ -217,7 +233,7 @@ func (i *bredthRangerIter) GetPath() []RangerEntry { return retv } -func (i *bredthRangerIter) Error() error { +func (i *rangerIter) Error() error { return nil } diff --git a/cidranger_test.go b/cidranger_test.go index 1af558e..65d8ec5 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -202,6 +202,62 @@ func TestBredthRangerIter(t *testing.T) { } } +func TestDepthRangerIter(t *testing.T) { + cases := []struct { + version rnet.IPVersion + inserts []string + expected []string + name string + }{ + {rnet.IPv4, []string{}, []string{}, "empty"}, + {rnet.IPv4, []string{"1.2.3.4/15"}, []string{"1.2.3.4/15"}, "single v4"}, + { + rnet.IPv4, + []string{ + "255.255.0.0/16", + "8.0.0.0/8", + "255.255.254.0/24", + "255.255.253.128/25", + "255.254.0.0/16", + "255.255.255.0/24", + "255.0.0.0/8", + }, + []string{ + "255.0.0.0/8", + "255.255.0.0/16", + "255.255.255.0/24", + "255.255.254.0/24", + "255.255.253.128/25", + "255.254.0.0/16", + "8.0.0.0/8", + }, + "nest, out-of-order inserts", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + trie := newPrefixTree(tc.version) + for _, insert := range tc.inserts { + _, network, _ := net.ParseCIDR(insert) + err := trie.Insert(NewBasicRangerEntry(*network)) + assert.NoError(t, err) + } + var expectedEntries []net.IPNet + for _, expected := range tc.expected { + _, network, _ := net.ParseCIDR(expected) + expectedEntries = append(expectedEntries, (*network)) + } + var resultEntries []net.IPNet + iter := NewIter(trie.(*prefixTrie), false, TraversalMethodDepth) + for iter.Next() { + entry := iter.Get() + resultEntries = append(resultEntries, entry.Network()) + } + assert.Equal(t, expectedEntries, resultEntries) + }) + } +} + // basic impl for a rollup based on weighted counts type RecordEntry struct { net.IPNet From 3bc086fe1c3f0fd10618741dbb7e02a72f195b51 Mon Sep 17 00:00:00 2001 From: Elliot Cubit Date: Fri, 26 May 2023 14:20:46 -0400 Subject: [PATCH 4/5] Add versionedRangerIter to support ipv4+ivp6 --- cidranger.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/cidranger.go b/cidranger.go index 2ea5393..a990c57 100644 --- a/cidranger.go +++ b/cidranger.go @@ -168,6 +168,58 @@ type rangerIter struct { method TraversalMethod } +type versionedRangerIter struct { + v4 RangerIter + v4Done bool + v6 RangerIter +} + +func NewVersionedIter(r Ranger, shallow bool, method TraversalMethod) *versionedRangerIter { + root, ok := r.(*versionedRanger) + if !ok { + panic(fmt.Errorf("Invalid type for versionedRangerIter")) + } + iter := versionedRangerIter{ + v4: NewIter(root.ipV4Ranger, shallow, method), + v6: NewIter(root.ipV6Ranger, shallow, method), + } + return &iter +} + +func (i *versionedRangerIter) Next() bool { + if i.v4.Next() { + return true + } + + i.v4Done = true + if i.v6.Next() { + return true + } + + return false +} + +func (i *versionedRangerIter) Get() RangerEntry { + if i.v4Done { + return i.v6.Get() + } + return i.v4.Get() +} + +func (i *versionedRangerIter) Error() error { + if err := i.v4.Error(); err != nil { + return err + } + return i.v6.Error() +} + +func (i *versionedRangerIter) GetPath() []RangerEntry { + if i.v4Done { + return i.v6.GetPath() + } + return i.v4.GetPath() +} + // A bredth-first iterator that returns all netblocks with a RangerEntry func NewBredthIter(r Ranger) *rangerIter { return NewIter(r, false, TraversalMethodBreadth) @@ -177,10 +229,11 @@ func NewBredthIter(r Ranger) *rangerIter { func NewShallowBredthIter(r Ranger) *rangerIter { return NewIter(r, true, TraversalMethodBreadth) } + func NewIter(r Ranger, shallow bool, method TraversalMethod) *rangerIter { root, ok := r.(*prefixTrie) if !ok { - panic(fmt.Errorf("Invalid type for bredthRangerIter")) + panic(fmt.Errorf("Invalid type for rangerIter")) } iter := rangerIter{ node: root, From 683c84ad62926ce020d6b09787e404e90c9a82ba Mon Sep 17 00:00:00 2001 From: Elliot Cubit Date: Fri, 26 May 2023 14:20:46 -0400 Subject: [PATCH 5/5] Correct spelling of breadth --- cidranger.go | 10 +++++----- cidranger_test.go | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cidranger.go b/cidranger.go index a990c57..d595272 100644 --- a/cidranger.go +++ b/cidranger.go @@ -142,7 +142,7 @@ func Subnets(base net.IPNet, prefixlen int) (subnets []net.IPNet, err error) { } // RangerIter is an interface to use with an iterator-like pattern -// ri := NewBredthIter(ptrie) +// ri := NewBreadthIter(ptrie) // for ri.Next() { // entry := ri.Get() // ... @@ -220,13 +220,13 @@ func (i *versionedRangerIter) GetPath() []RangerEntry { return i.v4.GetPath() } -// A bredth-first iterator that returns all netblocks with a RangerEntry -func NewBredthIter(r Ranger) *rangerIter { +// A breadth-first iterator that returns all netblocks with a RangerEntry +func NewBreadthIter(r Ranger) *rangerIter { return NewIter(r, false, TraversalMethodBreadth) } -// A bredth-first iterator that will return only the largest netblocks with an entry -func NewShallowBredthIter(r Ranger) *rangerIter { +// A breadth-first iterator that will return only the largest netblocks with an entry +func NewShallowBreadthIter(r Ranger) *rangerIter { return NewIter(r, true, TraversalMethodBreadth) } diff --git a/cidranger_test.go b/cidranger_test.go index 65d8ec5..95c46d1 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -162,7 +162,7 @@ func TestSubnets(t *testing.T) { } } -func TestBredthRangerIter(t *testing.T) { +func TestBreadthRangerIter(t *testing.T) { cases := []struct { version rnet.IPVersion inserts []string @@ -192,7 +192,7 @@ func TestBredthRangerIter(t *testing.T) { expectedEntries = append(expectedEntries, (*network)) } var resultEntries []net.IPNet - iter := NewBredthIter(trie.(*prefixTrie)) + iter := NewBreadthIter(trie.(*prefixTrie)) for iter.Next() { entry := iter.Get() resultEntries = append(resultEntries, entry.Network()) @@ -350,7 +350,7 @@ func TestRollupApply(t *testing.T) { assert.Errorf(t, err, "Expected error: %v", tc.err) } if tc.err == nil && err == nil { - iter := NewShallowBredthIter(trie.(*prefixTrie)) + iter := NewShallowBreadthIter(trie.(*prefixTrie)) for iter.Next() { entry := iter.Get() resultEntries = append(resultEntries, entry.(RecordEntry))