From 226a91405b361dd059ca126e453e03349d53b8e4 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sun, 31 Dec 2023 01:19:29 -0500 Subject: [PATCH] fix: handle leaf encoding correctly (#11) Leaf nodes should not be encoded with children --- pkg/btree/node.go | 22 +++++----- pkg/btree/node_test.go | 92 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 15 deletions(-) diff --git a/pkg/btree/node.go b/pkg/btree/node.go index d57a4112..9560f274 100644 --- a/pkg/btree/node.go +++ b/pkg/btree/node.go @@ -52,9 +52,11 @@ func (n *Node) WriteTo(w io.Writer) (int64, error) { return 0, err } } - for _, child := range n.Children { - if err := encoding.WriteUint64(w, child); err != nil { - return 0, err + if !n.Leaf { + for _, child := range n.Children { + if err := encoding.WriteUint64(w, child); err != nil { + return 0, err + } } } return int64(1 + 16*len(n.Keys) + 8*len(n.Children)), nil @@ -68,7 +70,6 @@ func (n *Node) ReadFrom(r io.Reader) (int64, error) { n.Leaf = size&(1<<7) != 0 size = size & (1<<7 - 1) n.Keys = make([]DataPointer, size) - n.Children = make([]uint64, size+1) for i := 0; i < int(size); i++ { recordOffset, err := encoding.ReadUint64(r) if err != nil { @@ -88,12 +89,15 @@ func (n *Node) ReadFrom(r io.Reader) (int64, error) { Length: length, } } - for i := 0; i <= int(size); i++ { - child, err := encoding.ReadUint64(r) - if err != nil { - return 0, err + if !n.Leaf { + n.Children = make([]uint64, size+1) + for i := 0; i <= int(size); i++ { + child, err := encoding.ReadUint64(r) + if err != nil { + return 0, err + } + n.Children[i] = child } - n.Children[i] = child } return 1 + 16*int64(size) + 8*int64(size+1), nil } diff --git a/pkg/btree/node_test.go b/pkg/btree/node_test.go index 21213706..3e3eb34c 100644 --- a/pkg/btree/node_test.go +++ b/pkg/btree/node_test.go @@ -7,7 +7,32 @@ import ( ) func TestNode(t *testing.T) { - t.Run("encode", func(t *testing.T) { + t.Run("encode leaf", func(t *testing.T) { + n := &Node{ + Keys: []DataPointer{ + { + RecordOffset: 0, + FieldOffset: 0, + Length: 5, + }, + { + RecordOffset: 0, + FieldOffset: 5, + Length: 5, + }, + }, + Leaf: true, + } + buf := &bytes.Buffer{} + if _, err := n.WriteTo(buf); err != nil { + t.Fatal(err) + } + if buf.Len() != 1+16*2 { + t.Fatalf("expected buffer length to be 1+16*2+8*3, got %d", buf.Len()) + } + }) + + t.Run("encode leaf ignores children", func(t *testing.T) { n := &Node{ Keys: []DataPointer{ { @@ -21,8 +46,34 @@ func TestNode(t *testing.T) { Length: 5, }, }, - Children: []uint64{0, 1, 2}, Leaf: true, + Children: []uint64{1, 2, 3}, + } + buf := &bytes.Buffer{} + if _, err := n.WriteTo(buf); err != nil { + t.Fatal(err) + } + if buf.Len() != 1+16*2 { + t.Fatalf("expected buffer length to be 1+16*2+8*3, got %d", buf.Len()) + } + }) + + t.Run("encode non-leaf", func(t *testing.T) { + n := &Node{ + Keys: []DataPointer{ + { + RecordOffset: 0, + FieldOffset: 0, + Length: 5, + }, + { + RecordOffset: 0, + FieldOffset: 5, + Length: 5, + }, + }, + Leaf: false, + Children: []uint64{1, 2, 3}, } buf := &bytes.Buffer{} if _, err := n.WriteTo(buf); err != nil { @@ -33,7 +84,7 @@ func TestNode(t *testing.T) { } }) - t.Run("decode", func(t *testing.T) { + t.Run("decode leaf", func(t *testing.T) { n := &Node{ Keys: []DataPointer{ { @@ -47,8 +98,37 @@ func TestNode(t *testing.T) { Length: 5, }, }, - Children: []uint64{0, 1, 2}, - Leaf: true, + Leaf: true, + } + buf := &bytes.Buffer{} + if _, err := n.WriteTo(buf); err != nil { + t.Fatal(err) + } + m := &Node{} + if _, err := m.ReadFrom(buf); err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(n, m) { + t.Fatalf("expected decoded node to be equal to original node, got %#v want %#v", m, n) + } + }) + + t.Run("decode non-leaf", func(t *testing.T) { + n := &Node{ + Keys: []DataPointer{ + { + RecordOffset: 0, + FieldOffset: 0, + Length: 5, + }, + { + RecordOffset: 0, + FieldOffset: 5, + Length: 5, + }, + }, + Leaf: false, + Children: []uint64{1, 2, 3}, } buf := &bytes.Buffer{} if _, err := n.WriteTo(buf); err != nil { @@ -59,7 +139,7 @@ func TestNode(t *testing.T) { t.Fatal(err) } if !reflect.DeepEqual(n, m) { - t.Fatalf("expected decoded node to be equal to original node") + t.Fatalf("expected decoded node to be equal to original node, got %#v want %#v", m, n) } }) }