From 2aa0d40b232e9e98966ab46d614b67483eee7177 Mon Sep 17 00:00:00 2001 From: DI LI Date: Sat, 23 Mar 2024 11:37:02 -0700 Subject: [PATCH 1/5] add test and benchmark for iptrie --- cidranger_test.go | 154 ++++++++++++++++-- go.mod | 10 +- go.sum | 6 +- iptire/trie.go | 405 ++++++++++++++++++++++++++++++++++++++++++++++ iptire/uint128.go | 81 ++++++++++ 5 files changed, 634 insertions(+), 22 deletions(-) create mode 100644 iptire/trie.go create mode 100644 iptire/uint128.go diff --git a/cidranger_test.go b/cidranger_test.go index c1c741e..adcb5ae 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -2,14 +2,17 @@ package cidranger import ( "encoding/json" + iptrie "github.com/yl2chen/cidranger/iptire" "io/ioutil" "math/rand" "net" + "net/netip" "testing" "time" "github.com/stretchr/testify/assert" rnet "github.com/yl2chen/cidranger/net" + "go4.org/netipx" ) /* @@ -51,10 +54,12 @@ func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { } rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() + trie := iptrie.NewTrie() for _, ranger := range rangers { configureRangerWithAWSRanges(t, ranger) } configureRangerWithAWSRanges(t, baseRanger) + configureTrieWithAWSRanges(t, trie) for i := 0; i < iterations; i++ { nn := ipGen() @@ -65,19 +70,40 @@ func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { assert.NoError(t, err) assert.Equal(t, expected, actual) } + addr, ok := netip.AddrFromSlice(nn.ToIP()) + if !ok { + t.Errorf("netip addr convert fail") + continue + } + got := trie.Find(addr) + var gotvalue bool + if got != nil { + gotvalue = true + } + assert.Equal(t, expected, gotvalue) } } +func testNormalizePrefix(pfx netip.Prefix) netip.Prefix { + if pfx.Addr().Is4() { + pfx = netip.PrefixFrom(netip.AddrFrom16(pfx.Addr().As16()), pfx.Bits()+96) + } + return pfx.Masked() +} + func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { if testing.Short() { t.Skip("Skipping 
memory test in `-short` mode") } rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() + trie := iptrie.NewTrie() + for _, ranger := range rangers { configureRangerWithAWSRanges(t, ranger) } configureRangerWithAWSRanges(t, baseRanger) + configureTrieWithAWSRanges(t, trie) for i := 0; i < iterations; i++ { nn := ipGen() @@ -91,6 +117,35 @@ func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGen assert.Contains(t, expected, network) } } + + addr, ok := netip.AddrFromSlice(nn.ToIP()) + if !ok { + t.Errorf("netip addr convert fail") + continue + } + got := trie.ContainingNetworks(addr) + assert.Equal(t, len(expected), len(got)) + builderExpected := new(netipx.IPSetBuilder) + builderGot := new(netipx.IPSetBuilder) + + for _, p := range expected { + n := p.Network() + prefix, ok := netipx.FromStdIPNet(&n) + if !ok { + t.Errorf("netip addr convert fail") + } + builderExpected.AddPrefix(testNormalizePrefix(prefix)) + } + expSet, err := builderExpected.IPSet() + + for _, g := range got { + builderGot.AddPrefix(g) + } + gotSet, err := builderGot.IPSet() + + if !expSet.Equal(gotSet) { + t.Errorf("not same set") + } } } @@ -100,10 +155,13 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG } rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() + trie := iptrie.NewTrie() + for _, ranger := range rangers { configureRangerWithAWSRanges(t, ranger) } configureRangerWithAWSRanges(t, baseRanger) + configureTrieWithAWSRanges(t, trie) for i := 0; i < iterations; i++ { network := netGen() @@ -117,6 +175,35 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG assert.Contains(t, expected, network) } } + + queryPrefix, ok := netipx.FromStdIPNet(&network.IPNet) + if !ok { + t.Errorf("netip addr convert fail") + } + + got := trie.CoveredNetworks(queryPrefix) + assert.Equal(t, len(expected), len(got)) + builderExpected := new(netipx.IPSetBuilder) + builderGot := 
new(netipx.IPSetBuilder) + + for _, p := range expected { + n := p.Network() + prefix, ok := netipx.FromStdIPNet(&n) + if !ok { + t.Errorf("netip addr convert fail") + } + builderExpected.AddPrefix(testNormalizePrefix(prefix)) + } + expSet, err := builderExpected.IPSet() + + for _, g := range got { + builderGot.AddPrefix(g) + } + gotSet, err := builderGot.IPSet() + + if !expSet.Equal(gotSet) { + t.Errorf("not same set") + } } } @@ -129,57 +216,65 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG func BenchmarkPCTrieHitIPv4UsingAWSRanges(b *testing.B) { benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger()) } -func BenchmarkBruteRangerHitIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger()) + +func BenchmarkTrieHitIPv4UsingAWSRanges(b *testing.B) { + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), iptrie.NewTrie()) } func BenchmarkPCTrieHitIPv6UsingAWSRanges(b *testing.B) { benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) } -func BenchmarkBruteRangerHitIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) + +func BenchmarkTrieHitIPv6UsingAWSRanges(b *testing.B) { + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), iptrie.NewTrie()) } func BenchmarkPCTrieMissIPv4UsingAWSRanges(b *testing.B) { benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger()) } -func BenchmarkBruteRangerMissIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger()) + +func BenchmarkTrieMissIPv4UsingAWSRanges(b *testing.B) { + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), iptrie.NewTrie()) } func BenchmarkPCTrieHMissIPv6UsingAWSRanges(b *testing.B) { 
benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger()) } -func BenchmarkBruteRangerMissIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainsUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger()) + +func BenchmarkTrieHMissIPv6UsingAWSRanges(b *testing.B) { + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), iptrie.NewTrie()) } func BenchmarkPCTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), NewPCTrieRanger()) } -func BenchmarkBruteRangerHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("52.95.110.1"), newBruteRanger()) + +func BenchmarkTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), iptrie.NewTrie()) } func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), NewPCTrieRanger()) } -func BenchmarkBruteRangerHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620:107:300f::36b7:ff81"), newBruteRanger()) + +func BenchmarkTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), iptrie.NewTrie()) } func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), NewPCTrieRanger()) } -func BenchmarkBruteRangerMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("123.123.123.123"), newBruteRanger()) + +func BenchmarkTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { + benchmarkTrieContainingNetworksUsingAWSRanges(b, 
netip.MustParseAddr("123.123.123.123"), iptrie.NewTrie()) } func BenchmarkPCTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), NewPCTrieRanger()) } -func BenchmarkBruteRangerMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkContainingNetworksUsingAWSRanges(b, net.ParseIP("2620::ffff"), newBruteRanger()) + +func BenchmarkTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), iptrie.NewTrie()) } func BenchmarkNewPathprefixTriev4(b *testing.B) { @@ -217,6 +312,20 @@ func benchmarkNewPathprefixTrie(b *testing.B, net1 string) { } } +func benchmarkTrieContainsUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie) { + configureTrieWithAWSRanges(tb, trie) + for n := 0; n < tb.(*testing.B).N; n++ { + trie.Find(nn) + } +} + +func benchmarkTrieContainingNetworksUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie) { + configureTrieWithAWSRanges(tb, trie) + for n := 0; n < tb.(*testing.B).N; n++ { + trie.ContainingNetworks(nn) + } +} + /* ****************************************************************** Helper methods and initialization. 
@@ -299,6 +408,19 @@ func configureRangerWithAWSRanges(tb testing.TB, ranger Ranger) { } } +func configureTrieWithAWSRanges(tb testing.TB, trie *iptrie.Trie) { + for _, prefix := range awsRanges.Prefixes { + network, err := netip.ParsePrefix(prefix.IPPrefix) + assert.NoError(tb, err) + trie.Insert(network, struct{}{}) + } + for _, prefix := range awsRanges.IPv6Prefixes { + network, err := netip.ParsePrefix(prefix.IPPrefix) + assert.NoError(tb, err) + trie.Insert(network, struct{}{}) + } +} + func init() { awsRanges = loadAWSRanges() for _, prefix := range awsRanges.IPv6Prefixes { diff --git a/go.mod b/go.mod index a35ea91..3ae3bdb 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,14 @@ module github.com/yl2chen/cidranger -go 1.13 +go 1.21 require ( github.com/stretchr/testify v1.6.1 - gopkg.in/yaml.v2 v2.2.2 // indirect + go4.org/netipx v0.0.0-20231129151722-fdeea329fbba +) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index d063842..7623d26 100644 --- a/go.sum +++ b/go.sum @@ -3,13 +3,11 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= +go4.org/netipx 
v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/iptire/trie.go b/iptire/trie.go new file mode 100644 index 0000000..1badc29 --- /dev/null +++ b/iptire/trie.go @@ -0,0 +1,405 @@ +// Package iptrie is a fork of github.com/yl2chen/cidranger. This fork massively strips down and refactors the code for +// increased performance, resulting in 20x faster load time, and 1.5x faster lookups. + +package iptrie + +import ( + "fmt" + "math/bits" + "net/netip" + "strings" + "unsafe" +) + +// Trie is an IP radix trie implementation, similar to what is described +// at https://vincent.bernat.im/en/blog/2017-ipv4-route-lookup-linux +// +// CIDR blocks are stored using a prefix tree structure where each node has its +// parent as prefix, and the path from the root node represents current CIDR +// block. +// +// Path compression compresses a string of node with only 1 child into a single +// node, decrease the amount of lookups necessary during containment tests. +type Trie struct { + parent *Trie + children [2]*Trie + + network netip.Prefix + value any +} + +// NewTrie creates a new Trie. +func NewTrie() *Trie { + return &Trie{ + network: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + } +} + +func newSubTree(network netip.Prefix, value any) *Trie { + return &Trie{ + network: network, + value: value, + } +} + +// Insert inserts a RangerEntry into prefix trie. 
+func (p *Trie) Insert(network netip.Prefix, value any) { + network = normalizePrefix(network) + p.insert(network, value) +} + +// Remove removes RangerEntry identified by given network from trie. +func (p *Trie) Remove(network netip.Prefix) any { + network = normalizePrefix(network) + return p.remove(network) +} + +// Find returns the value from the smallest prefix containing the given address. +func (p *Trie) Find(ip netip.Addr) any { + ip = normalizeAddr(ip) + return p.find(ip) +} + +// ContainingNetworks returns the list of RangerEntry(s) the given ip is +// contained in in ascending prefix order. +func (p *Trie) ContainingNetworks(ip netip.Addr) []netip.Prefix { + ip = normalizeAddr(ip) + return p.containingNetworks(ip) +} + +// CoveredNetworks returns the list of RangerEntry(s) the given ipnet +// covers. That is, the networks that are completely subsumed by the +// specified network. +func (p *Trie) CoveredNetworks(network netip.Prefix) []netip.Prefix { + network = normalizePrefix(network) + return p.coveredNetworks(network) +} + +func (p *Trie) Network() netip.Prefix { + return p.network +} + +// String returns string representation of trie, mainly for visualization and +// debugging. 
+func (p *Trie) String() string { + children := []string{} + padding := strings.Repeat("| ", p.level()+1) + for bit, child := range p.children { + if child == nil { + continue + } + childStr := fmt.Sprintf("\n%s%d--> %s", padding, bit, child.String()) + children = append(children, childStr) + } + return fmt.Sprintf("%s (has_entry:%t)%s", p.network, + p.value != nil, strings.Join(children, "")) +} + +func (p *Trie) find(number netip.Addr) any { + if !netContains(p.network, number) { + return nil + } + if p.value != nil { + return p.value + } + if p.network.Bits() == 128 { + return nil + } + bit := p.discriminatorBitFromIP(number) + child := p.children[bit] + if child != nil { + return child.find(number) + } + return nil +} + +func (p *Trie) containingNetworks(addr netip.Addr) []netip.Prefix { + var results []netip.Prefix + if !p.network.Contains(addr) { + return results + } + if p.value != nil { + results = []netip.Prefix{p.network} + } + if p.network.Bits() == 128 { + return results + } + bit := p.discriminatorBitFromIP(addr) + child := p.children[bit] + if child != nil { + ranges := child.containingNetworks(addr) + if len(ranges) > 0 { + if len(results) > 0 { + results = append(results, ranges...) 
+ } else { + results = ranges + } + } + } + return results +} + +func (p *Trie) coveredNetworks(network netip.Prefix) []netip.Prefix { + var results []netip.Prefix + if network.Bits() <= p.network.Bits() && network.Contains(p.network.Addr()) { + for entry := range p.walkDepth() { + results = append(results, entry) + } + } else if p.network.Bits() < 128 { + bit := p.discriminatorBitFromIP(network.Addr()) + child := p.children[bit] + if child != nil { + return child.coveredNetworks(network) + } + } + return results +} + +// This is an unsafe, but faster version of netip.Prefix.Contains +func netContains(pfx netip.Prefix, ip netip.Addr) bool { + pfxAddr := addr128(pfx.Addr()) + ipAddr := addr128(ip) + return ipAddr.xor(pfxAddr).and(mask6(pfx.Bits())).isZero() +} + +// netDivergence returns the largest prefix shared by the provided 2 prefixes +func netDivergence(net1 netip.Prefix, net2 netip.Prefix) netip.Prefix { + if net1.Bits() > net2.Bits() { + net1, net2 = net2, net1 + } + + if netContains(net1, net2.Addr()) { + return net1 + } + + diff := addr128(net1.Addr()).xor(addr128(net2.Addr())) + var bit int + if diff.hi != 0 { + bit = bits.LeadingZeros64(diff.hi) + } else { + bit = bits.LeadingZeros64(diff.lo) + 64 + } + if bit > net1.Bits() { + bit = net1.Bits() + } + pfx, _ := net1.Addr().Prefix(bit) + return pfx +} + +func (p *Trie) insert(network netip.Prefix, value any) *Trie { + if p.network == network { + p.value = value + return p + } + + bit := p.discriminatorBitFromIP(network.Addr()) + existingChild := p.children[bit] + + // No existing child, insert new leaf trie. + if existingChild == nil { + pNew := newSubTree(network, value) + p.appendTrie(bit, pNew) + return pNew + } + + // Check whether it is necessary to insert additional path prefix between current trie and existing child, + // in the case that inserted network diverges on its path to existing child. 
+ netdiv := netDivergence(existingChild.network, network) + if netdiv != existingChild.network { + pathPrefix := newSubTree(netdiv, nil) + p.insertPrefix(bit, pathPrefix, existingChild) + // Update new child + existingChild = pathPrefix + } + return existingChild.insert(network, value) +} + +func (p *Trie) appendTrie(bit uint8, prefix *Trie) { + p.children[bit] = prefix + prefix.parent = p +} + +func (p *Trie) insertPrefix(bit uint8, pathPrefix, child *Trie) { + // Set parent/child relationship between current trie and inserted pathPrefix + p.children[bit] = pathPrefix + pathPrefix.parent = p + + // Set parent/child relationship between inserted pathPrefix and original child + pathPrefixBit := pathPrefix.discriminatorBitFromIP(child.network.Addr()) + pathPrefix.children[pathPrefixBit] = child + child.parent = pathPrefix +} + +func (p *Trie) remove(network netip.Prefix) any { + if p.value != nil && p.network == network { + entry := p.value + p.value = nil + + p.compressPathIfPossible() + return entry + } + if p.network.Bits() == 128 { + return nil + } + bit := p.discriminatorBitFromIP(network.Addr()) + child := p.children[bit] + if child != nil { + return child.remove(network) + } + return nil +} + +func (p *Trie) qualifiesForPathCompression() bool { + // Current prefix trie can be path compressed if it meets all following. + // 1. records no CIDR entry + // 2. has single or no child + // 3. is not root trie + return p.value == nil && p.childrenCount() <= 1 && p.parent != nil +} + +func (p *Trie) compressPathIfPossible() { + if !p.qualifiesForPathCompression() { + // Does not qualify to be compressed + return + } + + // Find lone child. + var loneChild *Trie + for _, child := range p.children { + if child != nil { + loneChild = child + break + } + } + + // Find root of current single child lineage. 
+ parent := p.parent + for ; parent.qualifiesForPathCompression(); parent = parent.parent { + } + parentBit := parent.discriminatorBitFromIP(p.network.Addr()) + parent.children[parentBit] = loneChild + + // Attempts to further apply path compression at current lineage parent, in case current lineage + // compressed into parent. + parent.compressPathIfPossible() +} + +func (p *Trie) childrenCount() int { + count := 0 + for _, child := range p.children { + if child != nil { + count++ + } + } + return count +} + +func (p *Trie) discriminatorBitFromIP(addr netip.Addr) uint8 { + // This is a safe uint boxing of int since we should never attempt to get + // target bit at a negative position. + pos := p.network.Bits() + a128 := addr128(addr) + if pos < 64 { + return uint8(a128.hi >> (63 - pos) & 1) + } + return uint8(a128.lo >> (63 - (pos - 64)) & 1) +} + +func (p *Trie) level() int { + if p.parent == nil { + return 0 + } + return p.parent.level() + 1 +} + +// walkDepth walks the trie in depth order, for unit testing. +func (p *Trie) walkDepth() <-chan netip.Prefix { + entries := make(chan netip.Prefix) + go func() { + if p.value != nil { + entries <- p.network + } + childEntriesList := []<-chan netip.Prefix{} + for _, trie := range p.children { + if trie == nil { + continue + } + childEntriesList = append(childEntriesList, trie.walkDepth()) + } + for _, childEntries := range childEntriesList { + for entry := range childEntries { + entries <- entry + } + } + close(entries) + }() + return entries +} + +// TrieLoader can be used to improve the performance of bulk inserts to a Trie. It caches the node of the +// last insert in the tree, using it as the starting point to start searching for the location of the next insert. This +// is highly beneficial when the addresses are pre-sorted. 
+type TrieLoader struct { + trie *Trie + lastInsert *Trie +} + +func NewTrieLoader(trie *Trie) *TrieLoader { + return &TrieLoader{ + trie: trie, + lastInsert: trie, + } +} + +func (ptl *TrieLoader) Insert(pfx netip.Prefix, v any) { + pfx = normalizePrefix(pfx) + + diff := addr128(ptl.lastInsert.network.Addr()).xor(addr128(pfx.Addr())) + var pos int + if diff.hi != 0 { + pos = bits.LeadingZeros64(diff.hi) + } else { + pos = bits.LeadingZeros64(diff.lo) + 64 + } + if pos > pfx.Bits() { + pos = pfx.Bits() + } + if pos > ptl.lastInsert.network.Bits() { + pos = ptl.lastInsert.network.Bits() + } + + parent := ptl.lastInsert + for parent.network.Bits() > pos { + parent = parent.parent + } + ptl.lastInsert = parent.insert(pfx, v) +} + +func normalizeAddr(addr netip.Addr) netip.Addr { + if addr.Is4() { + return netip.AddrFrom16(addr.As16()) + } + return addr +} + +func normalizePrefix(pfx netip.Prefix) netip.Prefix { + if pfx.Addr().Is4() { + pfx = netip.PrefixFrom(netip.AddrFrom16(pfx.Addr().As16()), pfx.Bits()+96) + } + return pfx.Masked() +} + +func addr128(addr netip.Addr) uint128 { + return *(*uint128)(unsafe.Pointer(&addr)) +} + +func init() { + // Accessing the underlying data of a `netip.Addr` relies upon the data being + // in a known format, which is not guaranteed to be stable. So this init() + // function is to detect if it ever changes. + ip := netip.AddrFrom16([16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) + i128 := addr128(ip) + if i128.hi != 0x0001020304050607 || i128.lo != 0x08090a0b0c0d0e0f { + panic("netip.Addr format mismatch") + } +} diff --git a/iptire/uint128.go b/iptire/uint128.go new file mode 100644 index 0000000..ea0bff6 --- /dev/null +++ b/iptire/uint128.go @@ -0,0 +1,81 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iptrie + +import "math/bits" + +// uint128 represents a uint128 using two uint64s. 
+// +// When the methods below mention a bit number, bit 0 is the most +// significant bit (in hi) and bit 127 is the lowest (lo&1). +type uint128 struct { + hi uint64 + lo uint64 +} + +// mask6 returns a uint128 bitmask with the topmost n bits of a +// 128-bit number. +func mask6(n int) uint128 { + return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)} +} + +// isZero reports whether u == 0. +// +// It's faster than u == (uint128{}) because the compiler (as of Go +// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in +// its eq alg's generated code. +func (u uint128) isZero() bool { return u.hi|u.lo == 0 } + +// and returns the bitwise AND of u and m (u&m). +func (u uint128) and(m uint128) uint128 { + return uint128{u.hi & m.hi, u.lo & m.lo} +} + +// xor returns the bitwise XOR of u and m (u^m). +func (u uint128) xor(m uint128) uint128 { + return uint128{u.hi ^ m.hi, u.lo ^ m.lo} +} + +// or returns the bitwise OR of u and m (u|m). +func (u uint128) or(m uint128) uint128 { + return uint128{u.hi | m.hi, u.lo | m.lo} +} + +// not returns the bitwise NOT of u. +func (u uint128) not() uint128 { + return uint128{^u.hi, ^u.lo} +} + +// subOne returns u - 1. +func (u uint128) subOne() uint128 { + lo, borrow := bits.Sub64(u.lo, 1, 0) + return uint128{u.hi - borrow, lo} +} + +// addOne returns u + 1. +func (u uint128) addOne() uint128 { + lo, carry := bits.Add64(u.lo, 1, 0) + return uint128{u.hi + carry, lo} +} + +// halves returns the two uint64 halves of the uint128. +// +// Logically, think of it as returning two uint64s. +// It only returns pointers for inlining reasons on 32-bit platforms. +func (u *uint128) halves() [2]*uint64 { + return [2]*uint64{&u.hi, &u.lo} +} + +// bitsSetFrom returns a copy of u with the given bit +// and all subsequent ones set. +func (u uint128) bitsSetFrom(bit uint8) uint128 { + return u.or(mask6(int(bit)).not()) +} + +// bitsClearedFrom returns a copy of u with the given bit +// and all subsequent ones cleared. 
+func (u uint128) bitsClearedFrom(bit uint8) uint128 { + return u.and(mask6(int(bit))) +} From c1091d0a92d604c614b1ab27a09e1c39f635f417 Mon Sep 17 00:00:00 2001 From: DI LI Date: Sat, 23 Mar 2024 12:41:48 -0700 Subject: [PATCH 2/5] modify to return entry instead of just purely prefix --- cidranger_test.go | 4 ++-- iptire/trie.go | 33 +++++++++++++++++++-------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/cidranger_test.go b/cidranger_test.go index adcb5ae..05b3fe6 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -139,7 +139,7 @@ func testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGen expSet, err := builderExpected.IPSet() for _, g := range got { - builderGot.AddPrefix(g) + builderGot.AddPrefix(g.Network) } gotSet, err := builderGot.IPSet() @@ -197,7 +197,7 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG expSet, err := builderExpected.IPSet() for _, g := range got { - builderGot.AddPrefix(g) + builderGot.AddPrefix(g.Network) } gotSet, err := builderGot.IPSet() diff --git a/iptire/trie.go b/iptire/trie.go index 1badc29..31913c3 100644 --- a/iptire/trie.go +++ b/iptire/trie.go @@ -28,6 +28,11 @@ type Trie struct { value any } +type Entry struct { + Network netip.Prefix + Value any +} + // NewTrie creates a new Trie. func NewTrie() *Trie { return &Trie{ @@ -55,14 +60,14 @@ func (p *Trie) Remove(network netip.Prefix) any { } // Find returns the value from the smallest prefix containing the given address. -func (p *Trie) Find(ip netip.Addr) any { +func (p *Trie) Find(ip netip.Addr) *Entry { ip = normalizeAddr(ip) return p.find(ip) } // ContainingNetworks returns the list of RangerEntry(s) the given ip is // contained in in ascending prefix order. 
-func (p *Trie) ContainingNetworks(ip netip.Addr) []netip.Prefix { +func (p *Trie) ContainingNetworks(ip netip.Addr) []*Entry { ip = normalizeAddr(ip) return p.containingNetworks(ip) } @@ -70,7 +75,7 @@ func (p *Trie) ContainingNetworks(ip netip.Addr) []netip.Prefix { // CoveredNetworks returns the list of RangerEntry(s) the given ipnet // covers. That is, the networks that are completely subsumed by the // specified network. -func (p *Trie) CoveredNetworks(network netip.Prefix) []netip.Prefix { +func (p *Trie) CoveredNetworks(network netip.Prefix) []*Entry { network = normalizePrefix(network) return p.coveredNetworks(network) } @@ -95,12 +100,12 @@ func (p *Trie) String() string { p.value != nil, strings.Join(children, "")) } -func (p *Trie) find(number netip.Addr) any { +func (p *Trie) find(number netip.Addr) *Entry { if !netContains(p.network, number) { return nil } if p.value != nil { - return p.value + return &Entry{p.network, p.value} } if p.network.Bits() == 128 { return nil @@ -113,13 +118,13 @@ func (p *Trie) find(number netip.Addr) any { return nil } -func (p *Trie) containingNetworks(addr netip.Addr) []netip.Prefix { - var results []netip.Prefix +func (p *Trie) containingNetworks(addr netip.Addr) []*Entry { + var results []*Entry if !p.network.Contains(addr) { return results } if p.value != nil { - results = []netip.Prefix{p.network} + results = []*Entry{{p.network, p.value}} } if p.network.Bits() == 128 { return results @@ -139,8 +144,8 @@ func (p *Trie) containingNetworks(addr netip.Addr) []netip.Prefix { return results } -func (p *Trie) coveredNetworks(network netip.Prefix) []netip.Prefix { - var results []netip.Prefix +func (p *Trie) coveredNetworks(network netip.Prefix) []*Entry { + var results []*Entry if network.Bits() <= p.network.Bits() && network.Contains(p.network.Addr()) { for entry := range p.walkDepth() { results = append(results, entry) @@ -313,13 +318,13 @@ func (p *Trie) level() int { } // walkDepth walks the trie in depth order, for 
unit testing. -func (p *Trie) walkDepth() <-chan netip.Prefix { - entries := make(chan netip.Prefix) +func (p *Trie) walkDepth() <-chan *Entry { + entries := make(chan *Entry) go func() { if p.value != nil { - entries <- p.network + entries <- &Entry{p.network, p.value} } - childEntriesList := []<-chan netip.Prefix{} + childEntriesList := []<-chan *Entry{} for _, trie := range p.children { if trie == nil { continue From 21d1a05ed5cf3f77036f96617f4e6d9a36ce48f5 Mon Sep 17 00:00:00 2001 From: DI LI Date: Sat, 23 Mar 2024 13:07:37 -0700 Subject: [PATCH 3/5] rename Find to Contains --- cidranger_test.go | 10 +++------- iptire/trie.go | 13 ++++++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/cidranger_test.go b/cidranger_test.go index 05b3fe6..5009c83 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -75,12 +75,8 @@ func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { t.Errorf("netip addr convert fail") continue } - got := trie.Find(addr) - var gotvalue bool - if got != nil { - gotvalue = true - } - assert.Equal(t, expected, gotvalue) + got := trie.Contains(addr) + assert.Equal(t, expected, got) } } @@ -315,7 +311,7 @@ func benchmarkNewPathprefixTrie(b *testing.B, net1 string) { func benchmarkTrieContainsUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie) { configureTrieWithAWSRanges(tb, trie) for n := 0; n < tb.(*testing.B).N; n++ { - trie.Find(nn) + trie.Contains(nn) } } diff --git a/iptire/trie.go b/iptire/trie.go index 31913c3..109db95 100644 --- a/iptire/trie.go +++ b/iptire/trie.go @@ -59,8 +59,7 @@ func (p *Trie) Remove(network netip.Prefix) any { return p.remove(network) } -// Find returns the value from the smallest prefix containing the given address. 
-func (p *Trie) Find(ip netip.Addr) *Entry { +func (p *Trie) Contains(ip netip.Addr) bool { ip = normalizeAddr(ip) return p.find(ip) } @@ -100,22 +99,22 @@ func (p *Trie) String() string { p.value != nil, strings.Join(children, "")) } -func (p *Trie) find(number netip.Addr) *Entry { +func (p *Trie) find(number netip.Addr) bool { if !netContains(p.network, number) { - return nil + return false } if p.value != nil { - return &Entry{p.network, p.value} + return true } if p.network.Bits() == 128 { - return nil + return false } bit := p.discriminatorBitFromIP(number) child := p.children[bit] if child != nil { return child.find(number) } - return nil + return false } func (p *Trie) containingNetworks(addr netip.Addr) []*Entry { From 9d947812d0be052f30aeb72e59cd9c0bfc408e6f Mon Sep 17 00:00:00 2001 From: DI LI Date: Sat, 23 Mar 2024 14:52:42 -0700 Subject: [PATCH 4/5] make it generic, so we don't have to use type assertion at all. --- cidranger_test.go | 79 +++++++++++++++++++++------- iptire/trie.go | 123 ++++++++++++++++++++++++++++---------------- iptire/trie_test.go | 65 +++++++++++++++++++++++ 3 files changed, 204 insertions(+), 63 deletions(-) create mode 100644 iptire/trie_test.go diff --git a/cidranger_test.go b/cidranger_test.go index 5009c83..01d5dde 100644 --- a/cidranger_test.go +++ b/cidranger_test.go @@ -54,7 +54,7 @@ func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { } rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() - trie := iptrie.NewTrie() + trie := iptrie.NewTrie[struct{}]() for _, ranger := range rangers { configureRangerWithAWSRanges(t, ranger) } @@ -75,8 +75,12 @@ func testContainsAgainstBase(t *testing.T, iterations int, ipGen ipGenerator) { t.Errorf("netip addr convert fail") continue } - got := trie.Contains(addr) - assert.Equal(t, expected, got) + got := trie.Find(addr) + var gotvalue bool + if got != nil { + gotvalue = true + } + assert.Equal(t, expected, gotvalue) } } @@ -93,7 +97,7 @@ func 
testContainingNetworksAgainstBase(t *testing.T, iterations int, ipGen ipGen } rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() - trie := iptrie.NewTrie() + trie := iptrie.NewTrie[struct{}]() for _, ranger := range rangers { configureRangerWithAWSRanges(t, ranger) @@ -151,7 +155,7 @@ func testCoversNetworksAgainstBase(t *testing.T, iterations int, netGen networkG } rangers := []Ranger{NewPCTrieRanger()} baseRanger := newBruteRanger() - trie := iptrie.NewTrie() + trie := iptrie.NewTrie[struct{}]() for _, ranger := range rangers { configureRangerWithAWSRanges(t, ranger) @@ -214,7 +218,7 @@ func BenchmarkPCTrieHitIPv4UsingAWSRanges(b *testing.B) { } func BenchmarkTrieHitIPv4UsingAWSRanges(b *testing.B) { - benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), iptrie.NewTrie()) + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieHitIPv6UsingAWSRanges(b *testing.B) { @@ -222,7 +226,7 @@ func BenchmarkPCTrieHitIPv6UsingAWSRanges(b *testing.B) { } func BenchmarkTrieHitIPv6UsingAWSRanges(b *testing.B) { - benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), iptrie.NewTrie()) + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieMissIPv4UsingAWSRanges(b *testing.B) { @@ -230,7 +234,7 @@ func BenchmarkPCTrieMissIPv4UsingAWSRanges(b *testing.B) { } func BenchmarkTrieMissIPv4UsingAWSRanges(b *testing.B) { - benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), iptrie.NewTrie()) + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieHMissIPv6UsingAWSRanges(b *testing.B) { @@ -238,7 +242,7 @@ func BenchmarkPCTrieHMissIPv6UsingAWSRanges(b *testing.B) { } func BenchmarkTrieHMissIPv6UsingAWSRanges(b *testing.B) { - 
benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), iptrie.NewTrie()) + benchmarkTrieContainsUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { @@ -246,7 +250,7 @@ func BenchmarkPCTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { } func BenchmarkTrieHitContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), iptrie.NewTrie()) + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("52.95.110.1"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { @@ -254,7 +258,7 @@ func BenchmarkPCTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { } func BenchmarkTrieHitContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), iptrie.NewTrie()) + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620:107:300f::36b7:ff81"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { @@ -262,7 +266,7 @@ func BenchmarkPCTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { } func BenchmarkTrieMissContainingNetworksIPv4UsingAWSRanges(b *testing.B) { - benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), iptrie.NewTrie()) + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("123.123.123.123"), iptrie.NewTrie[struct{}]()) } func BenchmarkPCTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { @@ -270,7 +274,7 @@ func BenchmarkPCTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { } func BenchmarkTrieHMissContainingNetworksIPv6UsingAWSRanges(b *testing.B) { - benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), 
iptrie.NewTrie()) + benchmarkTrieContainingNetworksUsingAWSRanges(b, netip.MustParseAddr("2620::ffff"), iptrie.NewTrie[struct{}]()) } func BenchmarkNewPathprefixTriev4(b *testing.B) { @@ -281,6 +285,38 @@ func BenchmarkNewPathprefixTriev6(b *testing.B) { benchmarkNewPathprefixTrie(b, "8000::/24") } +func BenchmarkPCTLoad(b *testing.B) { + ranger := NewPCTrieRanger() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + for _, prefix := range awsRanges.Prefixes { + _, network, _ := net.ParseCIDR(prefix.IPPrefix) + _ = ranger.Insert(NewBasicRangerEntry(*network)) + } + for _, prefix := range awsRanges.IPv6Prefixes { + _, network, _ := net.ParseCIDR(prefix.IPPrefix) + _ = ranger.Insert(NewBasicRangerEntry(*network)) + } + } +} + +func BenchmarkTrieLoad(b *testing.B) { + trie := iptrie.NewTrie[struct{}]() + b.ResetTimer() + + for n := 0; n < b.N; n++ { + for _, prefix := range awsRanges.Prefixes { + network, _ := netip.ParsePrefix(prefix.IPPrefix) + trie.Insert(network, &struct{}{}) + } + for _, prefix := range awsRanges.IPv6Prefixes { + network, _ := netip.ParsePrefix(prefix.IPPrefix) + trie.Insert(network, &struct{}{}) + } + } +} + func benchmarkContainsUsingAWSRanges(tb testing.TB, nn net.IP, ranger Ranger) { configureRangerWithAWSRanges(tb, ranger) for n := 0; n < tb.(*testing.B).N; n++ { @@ -308,14 +344,21 @@ func benchmarkNewPathprefixTrie(b *testing.B, net1 string) { } } -func benchmarkTrieContainsUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie) { +func benchmarkTrieContainsUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie[struct{}]) { configureTrieWithAWSRanges(tb, trie) for n := 0; n < tb.(*testing.B).N; n++ { trie.Contains(nn) } } -func benchmarkTrieContainingNetworksUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie) { +func benchmarkTrieFindUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie[struct{}]) { + configureTrieWithAWSRanges(tb, trie) + for n := 0; n < tb.(*testing.B).N; n++ { + trie.Find(nn) + } +} + 
+func benchmarkTrieContainingNetworksUsingAWSRanges(tb testing.TB, nn netip.Addr, trie *iptrie.Trie[struct{}]) { configureTrieWithAWSRanges(tb, trie) for n := 0; n < tb.(*testing.B).N; n++ { trie.ContainingNetworks(nn) @@ -404,16 +447,16 @@ func configureRangerWithAWSRanges(tb testing.TB, ranger Ranger) { } } -func configureTrieWithAWSRanges(tb testing.TB, trie *iptrie.Trie) { +func configureTrieWithAWSRanges(tb testing.TB, trie *iptrie.Trie[struct{}]) { for _, prefix := range awsRanges.Prefixes { network, err := netip.ParsePrefix(prefix.IPPrefix) assert.NoError(tb, err) - trie.Insert(network, struct{}{}) + trie.Insert(network, &struct{}{}) } for _, prefix := range awsRanges.IPv6Prefixes { network, err := netip.ParsePrefix(prefix.IPPrefix) assert.NoError(tb, err) - trie.Insert(network, struct{}{}) + trie.Insert(network, &struct{}{}) } } diff --git a/iptire/trie.go b/iptire/trie.go index 109db95..65e3dff 100644 --- a/iptire/trie.go +++ b/iptire/trie.go @@ -20,53 +20,59 @@ import ( // // Path compression compresses a string of node with only 1 child into a single // node, decrease the amount of lookups necessary during containment tests. -type Trie struct { - parent *Trie - children [2]*Trie +type Trie[T any] struct { + parent *Trie[T] + children [2]*Trie[T] network netip.Prefix - value any + value *T } -type Entry struct { +type Entry[T any] struct { Network netip.Prefix - Value any + Value *T } // NewTrie creates a new Trie. -func NewTrie() *Trie { - return &Trie{ +func NewTrie[T any]() *Trie[T] { + return &Trie[T]{ network: netip.PrefixFrom(netip.IPv6Unspecified(), 0), } } -func newSubTree(network netip.Prefix, value any) *Trie { - return &Trie{ +func newSubTree[T any](network netip.Prefix, value *T) *Trie[T] { + return &Trie[T]{ network: network, value: value, } } // Insert inserts a RangerEntry into prefix trie. 
-func (p *Trie) Insert(network netip.Prefix, value any) { +func (p *Trie[T]) Insert(network netip.Prefix, value *T) { network = normalizePrefix(network) p.insert(network, value) } // Remove removes RangerEntry identified by given network from trie. -func (p *Trie) Remove(network netip.Prefix) any { +func (p *Trie[T]) Remove(network netip.Prefix) *T { network = normalizePrefix(network) return p.remove(network) } -func (p *Trie) Contains(ip netip.Addr) bool { +// Find returns the value from the smallest prefix containing the given address. +func (p *Trie[T]) Find(ip netip.Addr) *Entry[T] { ip = normalizeAddr(ip) return p.find(ip) } +func (p *Trie[T]) Contains(ip netip.Addr) bool { + ip = normalizeAddr(ip) + return p.contains(ip) +} + // ContainingNetworks returns the list of RangerEntry(s) the given ip is // contained in in ascending prefix order. -func (p *Trie) ContainingNetworks(ip netip.Addr) []*Entry { +func (p *Trie[T]) ContainingNetworks(ip netip.Addr) []*Entry[T] { ip = normalizeAddr(ip) return p.containingNetworks(ip) } @@ -74,18 +80,18 @@ func (p *Trie) ContainingNetworks(ip netip.Addr) []*Entry { // CoveredNetworks returns the list of RangerEntry(s) the given ipnet // covers. That is, the networks that are completely subsumed by the // specified network. -func (p *Trie) CoveredNetworks(network netip.Prefix) []*Entry { +func (p *Trie[T]) CoveredNetworks(network netip.Prefix) []*Entry[T] { network = normalizePrefix(network) return p.coveredNetworks(network) } -func (p *Trie) Network() netip.Prefix { +func (p *Trie[T]) Network() netip.Prefix { return p.network } // String returns string representation of trie, mainly for visualization and // debugging. 
-func (p *Trie) String() string { +func (p *Trie[T]) String() string { children := []string{} padding := strings.Repeat("| ", p.level()+1) for bit, child := range p.children { @@ -99,31 +105,57 @@ func (p *Trie) String() string { p.value != nil, strings.Join(children, "")) } -func (p *Trie) find(number netip.Addr) bool { +func (p *Trie[T]) contains(number netip.Addr) bool { if !netContains(p.network, number) { return false } + if p.value != nil { return true } + if p.network.Bits() == 128 { return false } bit := p.discriminatorBitFromIP(number) child := p.children[bit] if child != nil { - return child.find(number) + return child.contains(number) } + return false } -func (p *Trie) containingNetworks(addr netip.Addr) []*Entry { - var results []*Entry +func (p *Trie[T]) find(number netip.Addr) *Entry[T] { + if !netContains(p.network, number) { + return nil + } + + if p.network.Bits() == 128 { + return nil + } + bit := p.discriminatorBitFromIP(number) + child := p.children[bit] + if child != nil { + r := child.find(number) + if r != nil { + return r + } + } + + if p.value != nil { + return &Entry[T]{p.network, p.value} + } + return nil +} + +func (p *Trie[T]) containingNetworks(addr netip.Addr) []*Entry[T] { + var results []*Entry[T] if !p.network.Contains(addr) { return results } if p.value != nil { - results = []*Entry{{p.network, p.value}} + results = []*Entry[T]{{p.network, p.value}} } if p.network.Bits() == 128 { return results @@ -143,8 +175,8 @@ func (p *Trie) containingNetworks(addr netip.Addr) []*Entry { return results } -func (p *Trie) coveredNetworks(network netip.Prefix) []*Entry { - var results []*Entry +func (p *Trie[T]) coveredNetworks(network netip.Prefix) []*Entry[T] { + var results []*Entry[T] if network.Bits() <= p.network.Bits() && network.Contains(p.network.Addr()) { for entry := range p.walkDepth() { results = append(results, entry) @@ -190,7 +222,7 @@ func netDivergence(net1 netip.Prefix, net2 netip.Prefix) netip.Prefix { return pfx } -func (p 
*Trie) insert(network netip.Prefix, value any) *Trie { +func (p *Trie[T]) insert(network netip.Prefix, value *T) *Trie[T] { if p.network == network { p.value = value return p @@ -210,7 +242,8 @@ func (p *Trie) insert(network netip.Prefix, value any) *Trie { // in the case that inserted network diverges on its path to existing child. netdiv := netDivergence(existingChild.network, network) if netdiv != existingChild.network { - pathPrefix := newSubTree(netdiv, nil) + var x *T = nil + pathPrefix := newSubTree(netdiv, x) p.insertPrefix(bit, pathPrefix, existingChild) // Update new child existingChild = pathPrefix @@ -218,12 +251,12 @@ func (p *Trie) insert(network netip.Prefix, value any) *Trie { return existingChild.insert(network, value) } -func (p *Trie) appendTrie(bit uint8, prefix *Trie) { +func (p *Trie[T]) appendTrie(bit uint8, prefix *Trie[T]) { p.children[bit] = prefix prefix.parent = p } -func (p *Trie) insertPrefix(bit uint8, pathPrefix, child *Trie) { +func (p *Trie[T]) insertPrefix(bit uint8, pathPrefix, child *Trie[T]) { // Set parent/child relationship between current trie and inserted pathPrefix p.children[bit] = pathPrefix pathPrefix.parent = p @@ -234,7 +267,7 @@ func (p *Trie) insertPrefix(bit uint8, pathPrefix, child *Trie) { child.parent = pathPrefix } -func (p *Trie) remove(network netip.Prefix) any { +func (p *Trie[T]) remove(network netip.Prefix) *T { if p.value != nil && p.network == network { entry := p.value p.value = nil @@ -253,7 +286,7 @@ func (p *Trie) remove(network netip.Prefix) any { return nil } -func (p *Trie) qualifiesForPathCompression() bool { +func (p *Trie[T]) qualifiesForPathCompression() bool { // Current prefix trie can be path compressed if it meets all following. // 1. records no CIDR entry // 2. 
has single or no child @@ -261,14 +294,14 @@ func (p *Trie) qualifiesForPathCompression() bool { return p.value == nil && p.childrenCount() <= 1 && p.parent != nil } -func (p *Trie) compressPathIfPossible() { +func (p *Trie[T]) compressPathIfPossible() { if !p.qualifiesForPathCompression() { // Does not qualify to be compressed return } // Find lone child. - var loneChild *Trie + var loneChild *Trie[T] for _, child := range p.children { if child != nil { loneChild = child @@ -288,7 +321,7 @@ func (p *Trie) compressPathIfPossible() { parent.compressPathIfPossible() } -func (p *Trie) childrenCount() int { +func (p *Trie[T]) childrenCount() int { count := 0 for _, child := range p.children { if child != nil { @@ -298,7 +331,7 @@ func (p *Trie) childrenCount() int { return count } -func (p *Trie) discriminatorBitFromIP(addr netip.Addr) uint8 { +func (p *Trie[T]) discriminatorBitFromIP(addr netip.Addr) uint8 { // This is a safe uint boxing of int since we should never attempt to get // target bit at a negative position. pos := p.network.Bits() @@ -309,7 +342,7 @@ func (p *Trie) discriminatorBitFromIP(addr netip.Addr) uint8 { return uint8(a128.lo >> (63 - (pos - 64)) & 1) } -func (p *Trie) level() int { +func (p *Trie[T]) level() int { if p.parent == nil { return 0 } @@ -317,13 +350,13 @@ func (p *Trie) level() int { } // walkDepth walks the trie in depth order, for unit testing. -func (p *Trie) walkDepth() <-chan *Entry { - entries := make(chan *Entry) +func (p *Trie[T]) walkDepth() <-chan *Entry[T] { + entries := make(chan *Entry[T]) go func() { if p.value != nil { - entries <- &Entry{p.network, p.value} + entries <- &Entry[T]{p.network, p.value} } - childEntriesList := []<-chan *Entry{} + childEntriesList := []<-chan *Entry[T]{} for _, trie := range p.children { if trie == nil { continue @@ -343,19 +376,19 @@ func (p *Trie) walkDepth() <-chan *Entry { // TrieLoader can be used to improve the performance of bulk inserts to a Trie. 
It caches the node of the // last insert in the tree, using it as the starting point to start searching for the location of the next insert. This // is highly beneficial when the addresses are pre-sorted. -type TrieLoader struct { - trie *Trie - lastInsert *Trie +type TrieLoader[T any] struct { + trie *Trie[T] + lastInsert *Trie[T] } -func NewTrieLoader(trie *Trie) *TrieLoader { - return &TrieLoader{ +func NewTrieLoader[T any](trie *Trie[T]) *TrieLoader[T] { + return &TrieLoader[T]{ trie: trie, lastInsert: trie, } } -func (ptl *TrieLoader) Insert(pfx netip.Prefix, v any) { +func (ptl *TrieLoader[T]) Insert(pfx netip.Prefix, v *T) { pfx = normalizePrefix(pfx) diff := addr128(ptl.lastInsert.network.Addr()).xor(addr128(pfx.Addr())) diff --git a/iptire/trie_test.go b/iptire/trie_test.go new file mode 100644 index 0000000..baa697e --- /dev/null +++ b/iptire/trie_test.go @@ -0,0 +1,65 @@ +package iptrie + +import ( + "log" + "net/netip" + "testing" +) + +func TestFind(t *testing.T) { + trie := NewTrie[[]string]() + n1 := netip.MustParsePrefix("8.8.8.0/24") + trie.Insert(n1, &[]string{"ABC", "DEF"}) + + n2 := netip.MustParsePrefix("8.8.8.128/25") + trie.Insert(n2, &[]string{"ABC", "xyz"}) + + e1 := trie.Find(netip.MustParseAddr("8.8.8.130")) + if e1 == nil { + t.Errorf("shouldn't be nil") + } + + for _, a := range *e1.Value { + log.Printf("%+v\n", a) + } +} + +func TestContainNetwork(t *testing.T) { + trie := NewTrie[[]string]() + n1 := netip.MustParsePrefix("8.8.8.0/24") + trie.Insert(n1, &[]string{"ABC", "DEF"}) + + n2 := netip.MustParsePrefix("8.8.8.0/25") + trie.Insert(n2, &[]string{"ABC", "xyz"}) + + ee := trie.ContainingNetworks(netip.MustParseAddr("8.8.8.0")) + if ee == nil { + t.Errorf("should find") + } + + for _, e := range ee { + for _, a := range *e.Value { + log.Printf("%+v\n", a) + } + } +} + +func TestCoveredNetworks(t *testing.T) { + trie := NewTrie[[]string]() + n1 := netip.MustParsePrefix("8.8.8.0/24") + trie.Insert(n1, &[]string{"ABC", "DEF"}) + + n2 := 
netip.MustParsePrefix("8.8.8.0/25") + trie.Insert(n2, &[]string{"ABC", "xyz"}) + + ee := trie.CoveredNetworks(netip.MustParsePrefix("8.8.0.0/16")) + if ee == nil { + t.Errorf("should find") + } + + for _, e := range ee { + for _, a := range *e.Value { + log.Printf("%+v\n", a) + } + } +} From e1b24622494a8b7929c7bf8d6d3760e65971ba04 Mon Sep 17 00:00:00 2001 From: DI LI Date: Sat, 23 Mar 2024 14:58:54 -0700 Subject: [PATCH 5/5] add notes --- iptire/trie.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/iptire/trie.go b/iptire/trie.go index 65e3dff..b681560 100644 --- a/iptire/trie.go +++ b/iptire/trie.go @@ -1,6 +1,13 @@ // Package iptrie is a fork of github.com/yl2chen/cidranger. This fork massively strips down and refactors the code for // increased performance, resulting in 20x faster load time, and 1.5x faster lookups. +// Most of the code is from https://gist.github.com/phemmer/6231b12d5207ea93a1690ddc44a2c811 +// with the following modifications: +// 1. Add a Contains method to match the original Contains interface. +// 2. Fix Find to return the most specific entry instead of the first matching entry. +// 3. ContainingNetworks and CoveredNetworks now return entries instead of bare networks, which I believe is the whole point of this library. +// 4. Since entries are now returned, the whole code base uses generics, so we don't pay a runtime cost for type assertions. + package iptrie import (