Skip to content

Commit

Permalink
Prune node when inserting into data/empty record
Browse files Browse the repository at this point in the history
Previously we only pruned if the existing record was a node.
  • Loading branch information
oschwald committed Jul 28, 2023
1 parent bb691ac commit 7cb4df0
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 32 deletions.
67 changes: 36 additions & 31 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,7 @@ func (r *record) insert(
if err != nil {
return err
}

// Check to see if the children are the same and can be merged.
child0 := r.node.children[0]
child1 := r.node.children[1]
if child0.recordType != child1.recordType {
return nil
}
switch child0.recordType {
// Nodes can't be merged
case recordTypeFixedNode,
recordTypeNode:
return nil
case recordTypeEmpty,
recordTypeReserved:
r.recordType = child0.recordType
r.node = nil
return nil
case recordTypeData:
if child0.value.key != child1.value.key {
return nil
}
// Children have same data and can be merged
r.recordType = recordTypeData
r.value = child0.value
iRec.dataMap.remove(child1.value)
r.node = nil
return nil
default:
return fmt.Errorf("merging record type %d is not implemented", child0.recordType)
}
return r.maybeMergeChildren(iRec)
case recordTypeFixedNode:
return r.node.insert(iRec, newDepth)
case recordTypeEmpty, recordTypeData:
Expand Down Expand Up @@ -139,7 +110,11 @@ func (r *record) insert(
r.node = &node{children: [2]record{*r, *r}}
r.value = nil
r.recordType = recordTypeNode
return r.node.insert(iRec, newDepth)
err := r.node.insert(iRec, newDepth)
if err != nil {
return err
}
return r.maybeMergeChildren(iRec)
case recordTypeReserved:
if iRec.prefixLen >= newDepth {
return fmt.Errorf(
Expand Down Expand Up @@ -168,6 +143,36 @@ func (r *record) insert(
}
}

func (r *record) maybeMergeChildren(iRec insertRecord) error {
// Check to see if the children are the same and can be merged.
child0 := r.node.children[0]
child1 := r.node.children[1]
if child0.recordType != child1.recordType {
return nil
}
switch child0.recordType {
// Nodes can't be merged
case recordTypeFixedNode, recordTypeNode:
return nil
case recordTypeEmpty, recordTypeReserved:
r.recordType = child0.recordType
r.node = nil
return nil
case recordTypeData:
if child0.value.key != child1.value.key {
return nil
}
// Children have same data and can be merged
r.recordType = recordTypeData
r.value = child0.value
iRec.dataMap.remove(child1.value)
r.node = nil
return nil
default:
return fmt.Errorf("merging record type %d is not implemented", child0.recordType)
}
}

func (n *node) get(
ip net.IP,
depth int,
Expand Down
72 changes: 71 additions & 1 deletion tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ func TestTreeInsertAndGet(t *testing.T) {
expectedNodeCount: 368,
},
{
name: "node pruning",
name: "node pruning - adjacent",
inserts: []testInsert{
{
network: "1.1.0.0/24",
Expand Down Expand Up @@ -410,6 +410,76 @@ func TestTreeInsertAndGet(t *testing.T) {
},
expectedNodeCount: 366,
},
{
name: "node pruning - inserting smaller duplicate into larger",
inserts: []testInsert{
{
network: "1.1.0.0/24",
start: "1.1.0.0",
end: "1.1.0.255",
value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}},
},
{
network: "1.1.0.128/26",
start: "1.1.0.128",
end: "1.1.0.191",
// We intentionally don't use the same variable for
// here and above as we want them to be different instances.
value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}},
},
},
gets: []testGet{
{
ip: "1.1.0.0",
expectedNetwork: "1.1.0.0/24",
expectedGetValue: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}},
expectedLookupValue: func() *any {
v := any(map[string]any{"a": []any{uint64(1), []byte{1, 2}}})
return &v
}(),
},
},
expectedNodeCount: 367,
},
{
name: "node pruning - inserting smaller non-duplicate and then duplicate into larger",
inserts: []testInsert{
{
network: "1.1.0.0/24",
start: "1.1.0.0",
end: "1.1.0.255",
value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}},
},
{
network: "1.1.0.128/26",
start: "1.1.0.128",
end: "1.1.0.191",
// We intentionally don't use the same variable for
// here and above as we want them to be different instances.
value: mmdbtype.Map{"a": mmdbtype.Int32(1)},
},
{
network: "1.1.0.128/26",
start: "1.1.0.128",
end: "1.1.0.191",
// We intentionally don't use the same variable for
// here and above as we want them to be different instances.
value: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}},
},
},
gets: []testGet{
{
ip: "1.1.0.0",
expectedNetwork: "1.1.0.0/24",
expectedGetValue: mmdbtype.Map{"a": mmdbtype.Slice{mmdbtype.Uint64(1), mmdbtype.Bytes{1, 2}}},
expectedLookupValue: func() *any {
v := any(map[string]any{"a": []any{uint64(1), []byte{1, 2}}})
return &v
}(),
},
},
expectedNodeCount: 367,
},
{
name: "insertion of range with multiple subnets",
insertType: "range",
Expand Down

0 comments on commit 7cb4df0

Please sign in to comment.