diff --git a/node.go b/node.go index 61a47e5..c2148d5 100644 --- a/node.go +++ b/node.go @@ -71,36 +71,7 @@ func (r *record) insert( if err != nil { return err } - - // Check to see if the children are the same and can be merged. - child0 := r.node.children[0] - child1 := r.node.children[1] - if child0.recordType != child1.recordType { - return nil - } - switch child0.recordType { - // Nodes can't be merged - case recordTypeFixedNode, - recordTypeNode: - return nil - case recordTypeEmpty, - recordTypeReserved: - r.recordType = child0.recordType - r.node = nil - return nil - case recordTypeData: - if child0.value.key != child1.value.key { - return nil - } - // Children have same data and can be merged - r.recordType = recordTypeData - r.value = child0.value - iRec.dataMap.remove(child1.value) - r.node = nil - return nil - default: - return fmt.Errorf("merging record type %d is not implemented", child0.recordType) - } + return r.maybeMergeChildren(iRec) case recordTypeFixedNode: return r.node.insert(iRec, newDepth) case recordTypeEmpty, recordTypeData: @@ -139,7 +110,11 @@ func (r *record) insert( r.node = &node{children: [2]record{*r, *r}} r.value = nil r.recordType = recordTypeNode - return r.node.insert(iRec, newDepth) + err := r.node.insert(iRec, newDepth) + if err != nil { + return err + } + return r.maybeMergeChildren(iRec) case recordTypeReserved: if iRec.prefixLen >= newDepth { return fmt.Errorf( @@ -168,6 +143,36 @@ func (r *record) insert( } } +func (r *record) maybeMergeChildren(iRec insertRecord) error { + // Check to see if the children are the same and can be merged. + child0 := r.node.children[0] + child1 := r.node.children[1] + if child0.recordType != child1.recordType { + return nil + } + switch child0.recordType { + // Nodes can't be merged + case recordTypeFixedNode, recordTypeNode: + return nil + case recordTypeEmpty, recordTypeReserved: + r.recordType = child0.recordType + r.node = nil + return nil + case recordTypeData: + if child0.value.key != child1.value.key { + return nil + } + // Children have same data and can be merged + r.recordType = recordTypeData + r.value = child0.value + iRec.dataMap.remove(child1.value) + r.node = nil + return nil + default: + return fmt.Errorf("merging record type %d is not implemented", child0.recordType) + } +} + func (n *node) get( ip net.IP, depth int, diff --git a/tree_test.go b/tree_test.go index 3c8dbc9..f035172 100644 --- a/tree_test.go +++ b/tree_test.go @@ -380,7 +380,7 @@ func TestTreeInsertAndGet(t *testing.T) { expectedNodeCount: 368, }, { - name: "node pruning", + name: "node pruning - adjacent", inserts: []testInsert{ { network: "1.1.0.0/24", @@ -410,6 +410,76 @@ func TestTreeInsertAndGet(t *testing.T) { }, expectedNodeCount: 366, }, + { + name: "node pruning - inserting smaller duplicate into larger", + inserts: []testInsert{ + { + network: "1.1.0.0/24", + start: "1.1.0.0", + end: "1.1.0.255", + value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}}, + }, + { + network: "1.1.0.128/26", + start: "1.1.0.128", + end: "1.1.0.191", + // We intentionally don't use the same variable for + // here and above as we want them to be different instances. + value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}}, + }, + }, + gets: []testGet{ + { + ip: "1.1.0.0", + expectedNetwork: "1.1.0.0/24", + expectedGetValue: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}}, + expectedLookupValue: func() *any { + v := any(map[string]any{"a": []any{uint64(1), []byte{1, 2}}}) + return &v + }(), + }, + }, + expectedNodeCount: 367, + }, + { + name: "node pruning - inserting smaller non-duplicate and then duplicate into larger", + inserts: []testInsert{ + { + network: "1.1.0.0/24", + start: "1.1.0.0", + end: "1.1.0.255", + value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}}, + }, + { + network: "1.1.0.128/26", + start: "1.1.0.128", + end: "1.1.0.191", + // We intentionally don't use the same variable for + // here and above as we want them to be different instances. + value: mmdbtype.Map{"a": mmdbtype.Int32(1)}, + }, + { + network: "1.1.0.128/26", + start: "1.1.0.128", + end: "1.1.0.191", + // We intentionally don't use the same variable for + // here and above as we want them to be different instances. + value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}}, + }, + }, + gets: []testGet{ + { + ip: "1.1.0.0", + expectedNetwork: "1.1.0.0/24", + expectedGetValue: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}}, + expectedLookupValue: func() *any { + v := any(map[string]any{"a": []any{uint64(1), []byte{1, 2}}}) + return &v + }(), + }, + }, + expectedNodeCount: 367, + }, { name: "insertion of range with multiple subnets", insertType: "range",