Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to RangerIter #1

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 103 additions & 18 deletions cidranger.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ var AllIPv4 = parseCIDRUnsafe("0.0.0.0/0")
// AllIPv6 is a IPv6 CIDR that contains all networks
var AllIPv6 = parseCIDRUnsafe("0::0/0")

type TraversalMethod int

const (
TraversalMethodBreadth = iota
TraversalMethodDepth = iota
)

func parseCIDRUnsafe(s string) *net.IPNet {
_, cidr, _ := net.ParseCIDR(s)
return cidr
Expand Down Expand Up @@ -135,7 +142,7 @@ func Subnets(base net.IPNet, prefixlen int) (subnets []net.IPNet, err error) {
}

// RangerIter is an interface to use with an iterator-like pattern
// ri := NewBredthIter(ptrie)
// ri := NewBreadthIter(ptrie)
// for ri.Next() {
// entry := ri.Get()
// ...
Expand All @@ -150,39 +157,95 @@ func Subnets(base net.IPNet, prefixlen int) (subnets []net.IPNet, err error) {
type RangerIter interface {
Next() bool
Get() RangerEntry
GetPath() []RangerEntry
Error() error
}

type bredthRangerIter struct {
type rangerIter struct {
path *list.List
node *prefixTrie
shallow bool
method TraversalMethod
}

type versionedRangerIter struct {
v4 RangerIter
v4Done bool
v6 RangerIter
}

func NewVersionedIter(r Ranger, shallow bool, method TraversalMethod) *versionedRangerIter {
root, ok := r.(*versionedRanger)
if !ok {
panic(fmt.Errorf("Invalid type for versionedRangerIter"))
}
iter := versionedRangerIter{
v4: NewIter(root.ipV4Ranger, shallow, method),
v6: NewIter(root.ipV6Ranger, shallow, method),
}
return &iter
}

// A bredth-first iterator that returns all netblocks with a RangerEntry
func NewBredthIter(r Ranger) bredthRangerIter {
return newBredthIter(r, false)
func (i *versionedRangerIter) Next() bool {
if i.v4.Next() {
return true
}

i.v4Done = true
if i.v6.Next() {
return true
}

return false
}

// A bredth-first iterator that will return only the largest netblocks with an entry
func NewShallowBredthIter(r Ranger) bredthRangerIter {
return newBredthIter(r, true)
func (i *versionedRangerIter) Get() RangerEntry {
if i.v4Done {
return i.v6.Get()
}
return i.v4.Get()
}
func newBredthIter(r Ranger, shallow bool) bredthRangerIter {

func (i *versionedRangerIter) Error() error {
if err := i.v4.Error(); err != nil {
return err
}
return i.v6.Error()
}

func (i *versionedRangerIter) GetPath() []RangerEntry {
if i.v4Done {
return i.v6.GetPath()
}
return i.v4.GetPath()
}

// A breadth-first iterator that returns all netblocks with a RangerEntry
func NewBreadthIter(r Ranger) *rangerIter {
return NewIter(r, false, TraversalMethodBreadth)
}

// A breadth-first iterator that will return only the largest netblocks with an entry
func NewShallowBreadthIter(r Ranger) *rangerIter {
return NewIter(r, true, TraversalMethodBreadth)
}

func NewIter(r Ranger, shallow bool, method TraversalMethod) *rangerIter {
root, ok := r.(*prefixTrie)
if !ok {
panic(fmt.Errorf("Invalid type for bredthRangerIter"))
panic(fmt.Errorf("Invalid type for rangerIter"))
}
iter := bredthRangerIter{
iter := rangerIter{
node: root,
path: list.New(),
shallow: shallow,
method: method,
}
iter.path.PushBack(root)
return iter
return &iter
}

func (i *bredthRangerIter) Next() bool {
func (i *rangerIter) Next() bool {
for i.path.Len() > 0 {
element := i.path.Front()
i.path.Remove(element)
Expand All @@ -192,7 +255,14 @@ func (i *bredthRangerIter) Next() bool {
}
for _, child := range i.node.children {
if child != nil {
i.path.PushBack(child)
switch i.method {
case TraversalMethodBreadth:
i.path.PushBack(child)
case TraversalMethodDepth:
i.path.PushFront(child)
default:
panic(fmt.Sprintf("Unrecognized TraversalMethod %v", i.method))
}
}
}
if i.node.hasEntry() {
Expand All @@ -202,11 +272,21 @@ func (i *bredthRangerIter) Next() bool {
return false
}

func (i *bredthRangerIter) Get() RangerEntry {
func (i *rangerIter) Get() RangerEntry {
return i.node.entry
}

func (i *bredthRangerIter) Error() error {
func (i *rangerIter) GetPath() []RangerEntry {
retv := make([]RangerEntry, 0)
for this := i.node; this.parent != nil; this = this.parent {
if this.hasEntry() {
retv = append(retv, this.entry)
}
}
return retv
}

func (i *rangerIter) Error() error {
return nil
}

Expand Down Expand Up @@ -271,8 +351,13 @@ func getRollupEntries(trie *prefixTrie, f RollupApply) []RangerEntry {
// If both have an entry, check to rollup
if node.children[0].hasEntry() && node.children[1].hasEntry() {
if f.CanRollup(node.children[0].entry, node.children[1].entry) {
rollupEntries = append(rollupEntries,
f.GetParentEntry(node.children[0].entry, node.children[1].entry, node.network.IPNet),
rollupEntries = append(
rollupEntries,
f.GetParentEntry(
node.children[0].entry,
node.children[1].entry,
node.network.IPNet,
),
)
}
continue
Expand Down
95 changes: 87 additions & 8 deletions cidranger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,13 @@ func TestSubnets(t *testing.T) {
name string
}{
{"0.0.0.0/8", 33, nil, rnet.ErrBadMaskLength, "IPv4 prefix too long"},
{"0.0.0.0/0", 2, []string{"0.0.0.0/2", "64.0.0.0/2", "128.0.0.0/2", "192.0.0.0/2"}, nil, "IPv4 /0 to /2"},
{
"0.0.0.0/0",
2,
[]string{"0.0.0.0/2", "64.0.0.0/2", "128.0.0.0/2", "192.0.0.0/2"},
nil,
"IPv4 /0 to /2",
},
{"10.0.0.0/8", 0, []string{"10.0.0.0/9", "10.128.0.0/9"}, nil, "IPv4 default split /8"},
{"::/2", 4, []string{"::/4", "1000::/4", "2000::/4", "3000::/4"}, nil, "IPv6 /2 to /4"},
{"10.0.0.0/15", 15, []string{"10.0.0.0/15"}, nil, "IPv4 prefix self"},
Expand All @@ -156,7 +162,7 @@ func TestSubnets(t *testing.T) {
}
}

func TestBredthRangerIter(t *testing.T) {
func TestBreadthRangerIter(t *testing.T) {
cases := []struct {
version rnet.IPVersion
inserts []string
Expand All @@ -165,7 +171,68 @@ func TestBredthRangerIter(t *testing.T) {
}{
{rnet.IPv4, []string{}, []string{}, "empty"},
{rnet.IPv4, []string{"1.2.3.4/15"}, []string{"1.2.3.4/15"}, "single v4"},
{rnet.IPv4, []string{"255.0.0.0/8", "0.0.0.0/8"}, []string{"0.0.0.0/8", "255.0.0.0/8"}, "ordering v4"},
{
rnet.IPv4,
[]string{"255.0.0.0/8", "0.0.0.0/8"},
[]string{"0.0.0.0/8", "255.0.0.0/8"},
"ordering v4",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
trie := newPrefixTree(tc.version)
for _, insert := range tc.inserts {
_, network, _ := net.ParseCIDR(insert)
err := trie.Insert(NewBasicRangerEntry(*network))
assert.NoError(t, err)
}
var expectedEntries []net.IPNet
for _, expected := range tc.expected {
_, network, _ := net.ParseCIDR(expected)
expectedEntries = append(expectedEntries, (*network))
}
var resultEntries []net.IPNet
iter := NewBreadthIter(trie.(*prefixTrie))
for iter.Next() {
entry := iter.Get()
resultEntries = append(resultEntries, entry.Network())
}
assert.Equal(t, expectedEntries, resultEntries)
})
}
}

func TestDepthRangerIter(t *testing.T) {
cases := []struct {
version rnet.IPVersion
inserts []string
expected []string
name string
}{
{rnet.IPv4, []string{}, []string{}, "empty"},
{rnet.IPv4, []string{"1.2.3.4/15"}, []string{"1.2.3.4/15"}, "single v4"},
{
rnet.IPv4,
[]string{
"255.255.0.0/16",
"8.0.0.0/8",
"255.255.254.0/24",
"255.255.253.128/25",
"255.254.0.0/16",
"255.255.255.0/24",
"255.0.0.0/8",
},
[]string{
"255.0.0.0/8",
"255.255.0.0/16",
"255.255.255.0/24",
"255.255.254.0/24",
"255.255.253.128/25",
"255.254.0.0/16",
"8.0.0.0/8",
},
"nest, out-of-order inserts",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -181,7 +248,7 @@ func TestBredthRangerIter(t *testing.T) {
expectedEntries = append(expectedEntries, (*network))
}
var resultEntries []net.IPNet
iter := NewBredthIter(trie.(*prefixTrie))
iter := NewIter(trie.(*prefixTrie), false, TraversalMethodDepth)
for iter.Next() {
entry := iter.Get()
resultEntries = append(resultEntries, entry.Network())
Expand Down Expand Up @@ -218,7 +285,11 @@ func (rc rollupCount) CanRollup(c0 RangerEntry, c1 RangerEntry) bool {
return rc0.Records+rc1.Records < rc.rollupThreshold
}

func (rc rollupCount) GetParentEntry(c0 RangerEntry, c1 RangerEntry, parentNet net.IPNet) RangerEntry {
func (rc rollupCount) GetParentEntry(
c0 RangerEntry,
c1 RangerEntry,
parentNet net.IPNet,
) RangerEntry {
rc0, ok := c0.(RecordEntry)
if !ok {
panic(c0)
Expand Down Expand Up @@ -279,7 +350,7 @@ func TestRollupApply(t *testing.T) {
assert.Errorf(t, err, "Expected error: %v", tc.err)
}
if tc.err == nil && err == nil {
iter := NewShallowBredthIter(trie.(*prefixTrie))
iter := NewShallowBreadthIter(trie.(*prefixTrie))
for iter.Next() {
entry := iter.Get()
resultEntries = append(resultEntries, entry.(RecordEntry))
Expand Down Expand Up @@ -332,10 +403,18 @@ func BenchmarkBruteRangerHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) {
}

func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) {
benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger())
benchmarkContainingNetworksUsingAWSRanges(
b,
net.ParseIP("2620:107:300f::36b7:ff81"),
NewPCTrieRanger(),
)
}
func BenchmarkBruteRangerHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) {
benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger())
benchmarkContainingNetworksUsingAWSRanges(
b,
net.ParseIP("2620:107:300f::36b7:ff81"),
newBruteRanger(),
)
}

func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) {
Expand Down
20 changes: 17 additions & 3 deletions trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ func TestPrefixTrieInsert(t *testing.T) {
assert.NoError(t, err)
}

assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match")
assert.Equal(
t,
len(tc.expectedNetworksInDepthOrder),
trie.Len(),
"trie size should match",
)

allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version))
assert.Nil(t, err)
Expand Down Expand Up @@ -213,7 +218,12 @@ func TestPrefixTrieRemove(t *testing.T) {
}
}

assert.Equal(t, len(tc.expectedNetworksInDepthOrder), trie.Len(), "trie size should match after revmoval")
assert.Equal(
t,
len(tc.expectedNetworksInDepthOrder),
trie.Len(),
"trie size should match after revmoval",
)

allNetworks, err := trie.CoveredNetworks(*getAllByVersion(tc.version))
assert.Nil(t, err)
Expand Down Expand Up @@ -686,7 +696,11 @@ func TestTrieMemUsage(t *testing.T) {

// Assert that heap allocation from first loop is within set threshold of avg over all runs.
assert.Less(t, uint64(0), baseLineHeap)
assert.LessOrEqual(t, float64(baseLineHeap), float64(totalHeapAllocOverRuns/uint64(runs))*thresh)
assert.LessOrEqual(
t,
float64(baseLineHeap),
float64(totalHeapAllocOverRuns/uint64(runs))*thresh,
)
}

func GenLeafIPNet(ip net.IP) net.IPNet {
Expand Down