Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Fix proof direction and key bit order #22

Merged
merged 9 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mapstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type InvalidKeyError struct {
}

func (e *InvalidKeyError) Error() string {
return fmt.Sprintf("invalid key: %s", e.Key)
return fmt.Sprintf("invalid key: %x", e.Key)
}

// SimpleMap is a simple in-memory map.
Expand Down
8 changes: 4 additions & 4 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va
}

// Recompute root.
for i := len(proof.SideNodes) - 1; i >= 0; i-- {
for i := 0; i < len(proof.SideNodes); i++ {
node := make([]byte, th.pathSize())
copy(node, proof.SideNodes[i])

if hasBit(path, i) == right {
if getBitAtFromMSB(path, len(proof.SideNodes)-1-i) == right {
currentHash, currentData = th.digestNode(node, currentHash)
} else {
currentHash, currentData = th.digestNode(currentHash, node)
Expand Down Expand Up @@ -170,7 +170,7 @@ func CompactProof(proof SparseMerkleProof, hasher hash.Hash) (SparseCompactMerkl
node := make([]byte, th.hasher.Size())
copy(node, proof.SideNodes[i])
if bytes.Equal(node, th.placeholder()) {
setBit(bitMask, i)
setBitAtFromMSB(bitMask, i)
} else {
compactedSideNodes = append(compactedSideNodes, node)
}
Expand All @@ -195,7 +195,7 @@ func DecompactProof(proof SparseCompactMerkleProof, hasher hash.Hash) (SparseMer
decompactedSideNodes := make([][]byte, proof.NumSideNodes)
position := 0
for i := 0; i < proof.NumSideNodes; i++ {
if hasBit(proof.BitMask, i) == 1 {
if getBitAtFromMSB(proof.BitMask, i) == 1 {
decompactedSideNodes[i] = th.placeholder()
} else {
decompactedSideNodes[i] = proof.SideNodes[position]
Expand Down
4 changes: 2 additions & 2 deletions proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func TestCompactProofsSanityCheck(t *testing.T) {

// Case (compact proofs): unexpected bit mask length.
proof, _ = smt.ProveCompact([]byte("testKey1"))
proof.NumSideNodes = 1
proof.NumSideNodes = 10
if proof.sanityCheck(th) {
t.Error("sanity check incorrectly passed")
}
Expand All @@ -221,7 +221,7 @@ func TestCompactProofsSanityCheck(t *testing.T) {

// Case (compact proofs): unexpected number of sidenodes for number of side nodes.
proof, _ = smt.ProveCompact([]byte("testKey1"))
proof.SideNodes = proof.SideNodes[:1]
proof.SideNodes = append(proof.SideNodes, proof.SideNodes...)
if proof.sanityCheck(th) {
t.Error("sanity check incorrectly passed")
}
Expand Down
38 changes: 22 additions & 16 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (smt *SparseMerkleTree) GetForRoot(key []byte, root []byte) ([]byte, error)
}

leftNode, rightNode := smt.th.parseNode(currentData)
if hasBit(path, i) == right {
if getBitAtFromMSB(path, i) == right {
currentHash = rightNode
} else {
currentHash = leftNode
Expand Down Expand Up @@ -188,7 +188,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte

var currentHash, currentData []byte
nonPlaceholderReached := false
for i := smt.depth() - 1; i >= 0; i-- {
for i := 0; i < len(sideNodes); i++ {
if sideNodes[i] == nil {
continue
}
Expand All @@ -215,7 +215,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte
}

if !nonPlaceholderReached && bytes.Equal(sideNode, smt.th.placeholder()) {
// We found another placeholder sibling node, keep going down the
// We found another placeholder sibling node, keep going up the
// tree until we find the first sibling that is not a placeholder.
continue
} else if !nonPlaceholderReached {
Expand All @@ -224,7 +224,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte
nonPlaceholderReached = true
}

if hasBit(path, i) == right {
if getBitAtFromMSB(path, len(sideNodes)-1-i) == right {
currentHash, currentData = smt.th.digestNode(sideNode, currentData)
} else {
currentHash, currentData = smt.th.digestNode(currentData, sideNode)
Expand Down Expand Up @@ -269,7 +269,7 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
commonPrefixCount = countCommonPrefix(path, actualPath)
}
if commonPrefixCount != smt.depth() {
if hasBit(path, commonPrefixCount) == right {
if getBitAtFromMSB(path, commonPrefixCount) == right {
currentHash, currentData = smt.th.digestNode(oldLeafHash, currentData)
} else {
currentHash, currentData = smt.th.digestNode(currentData, oldLeafHash)
Expand All @@ -283,11 +283,15 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
currentData = currentHash
}

for i := smt.depth() - 1; i >= 0; i-- {
for i := 0; i < smt.depth(); i++ {
sideNode := make([]byte, smt.th.pathSize())

if sideNodes[i] == nil {
if commonPrefixCount != smt.depth() && commonPrefixCount > i {
// The offset from the bottom of the tree to the start of the side nodes
// i-offsetOfSideNodes is the index into sideNodes[]
offsetOfSideNodes := smt.depth() - len(sideNodes)

if i-offsetOfSideNodes < 0 || sideNodes[i-offsetOfSideNodes] == nil {
if commonPrefixCount != smt.depth() && commonPrefixCount > smt.depth()-1-i {
// If there are no sidenodes at this height, but the number of
// bits that the paths of the two leaf nodes share in common is
// greater than this height, then we need to build up the tree
Expand All @@ -297,10 +301,10 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
continue
}
} else {
copy(sideNode, sideNodes[i])
copy(sideNode, sideNodes[i-offsetOfSideNodes])
}

if hasBit(path, i) == right {
if getBitAtFromMSB(path, smt.depth()-1-i) == right {
currentHash, currentData = smt.th.digestNode(sideNode, currentData)
} else {
currentHash, currentData = smt.th.digestNode(currentData, sideNode)
Expand All @@ -319,7 +323,9 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side
// Returns an array of sibling nodes, the leaf hash found at that path and the
// leaf data. If the leaf is a placeholder, the leaf data is nil.
func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byte, []byte, []byte, error) {
sideNodes := make([][]byte, smt.depth())
// Side nodes for the path. Nodes are inserted in reverse order, then the
// slice is reversed at the end.
sideNodes := make([][]byte, 0, smt.depth())

if bytes.Equal(root, smt.th.placeholder()) {
// If the root is a placeholder, there are no sidenodes to return.
Expand All @@ -340,17 +346,17 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt
leftNode, rightNode := smt.th.parseNode(currentData)

// Get sidenode depending on whether the path bit is on or off.
if hasBit(path, i) == right {
sideNodes[i] = leftNode
if getBitAtFromMSB(path, i) == right {
sideNodes = append(sideNodes, leftNode)
nodeHash = rightNode
} else {
sideNodes[i] = rightNode
sideNodes = append(sideNodes, rightNode)
nodeHash = leftNode
}

if bytes.Equal(nodeHash, smt.th.placeholder()) {
// If the node is a placeholder, we've reached the end.
return sideNodes, nodeHash, nil, nil
return reverseSideNodes(sideNodes), nodeHash, nil, nil
}

currentData, err = smt.ms.Get(nodeHash)
Expand All @@ -362,7 +368,7 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt
}
}

return sideNodes, nodeHash, currentData, err
return reverseSideNodes(sideNodes), nodeHash, currentData, nil
}

// Prove generates a Merkle proof for a key.
Expand Down
128 changes: 123 additions & 5 deletions smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,127 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
}
}

// Test tree operations when two leafs are immediate neighbours.
// Test known tree ops
func TestSparseMerkleTreeKnown(t *testing.T) {
h := newDummyHasher(sha256.New())
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, h)
var value []byte
var err error

baseKey := make([]byte, h.Size()+4)
key1 := make([]byte, h.Size()+4)
copy(key1, baseKey)
key1[4] = byte(0b00000000)
key2 := make([]byte, h.Size()+4)
copy(key2, baseKey)
key2[4] = byte(0b01000000)
key3 := make([]byte, h.Size()+4)
copy(key3, baseKey)
key3[4] = byte(0b10000000)
key4 := make([]byte, h.Size()+4)
copy(key4, baseKey)
key4[4] = byte(0b11000000)
key5 := make([]byte, h.Size()+4)
copy(key5, baseKey)
key5[4] = byte(0b11010000)

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key2, []byte("testValue2"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key3, []byte("testValue3"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key4, []byte("testValue4"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key5, []byte("testValue5"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}

value, err = smt.Get(key1)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue1"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key2)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue2"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key3)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue3"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key4)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue4"), value) {
t.Error("did not get correct value when getting non-empty key")
}
value, err = smt.Get(key5)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue5"), value) {
t.Error("did not get correct value when getting non-empty key")
}

proof1, _ := smt.Prove(key1)
proof2, _ := smt.Prove(key2)
proof3, _ := smt.Prove(key3)
proof4, _ := smt.Prove(key4)
proof5, _ := smt.Prove(key5)
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), h, smt.Root())
err = dsmst.AddBranch(proof1, key1, []byte("testValue1"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof2, key2, []byte("testValue2"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof3, key3, []byte("testValue3"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof4, key4, []byte("testValue4"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
err = dsmst.AddBranch(proof5, key5, []byte("testValue5"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
}
}

// Test tree operations when two leafs are immediate neighbors.
func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
h := newDummyHasher(sha256.New())
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, h)
var value []byte
var err error

// Make two neighbouring keys.
// Make two neighboring keys.
//
// The dummy hash function excepts keys to prefixed with four bytes of 0,
// The dummy hash function expects keys to prefixed with four bytes of 0,
// which will cause it to return the preimage itself as the digest, without
// the first four bytes.
key1 := make([]byte, h.Size()+4)
Expand All @@ -187,7 +297,8 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
key1[h.Size()+4-1] = byte(0)
key2 := make([]byte, h.Size()+4)
copy(key2, key1)
setBit(key2, (h.Size()+4)*8-1)
// We make key2's least significant bit different than key1's
key2[h.Size()+4-1] = byte(1)

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
Expand All @@ -205,14 +316,21 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
if !bytes.Equal([]byte("testValue1"), value) {
t.Error("did not get correct value when getting non-empty key")
}

value, err = smt.Get(key2)
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue2"), value) {
t.Error("did not get correct value when getting non-empty key")
}

proof, err := smt.Prove(key1)
if err != nil {
t.Errorf("returned error when proving key: %v", err)
}
if len(proof.SideNodes) != 256 {
t.Errorf("unexpected proof size")
}
}

// Test base case tree delete operations with a few keys.
Expand Down
22 changes: 16 additions & 6 deletions utils.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
package smt

func hasBit(data []byte, position int) int {
if int(data[position/8])&(1<<(uint(position)%8)) > 0 {
// getBitAtFromMSB gets the bit at an offset from the most significant bit
func getBitAtFromMSB(data []byte, position int) int {
if int(data[position/8])&(1<<(8-1-uint(position)%8)) > 0 {
return 1
}
return 0
}

func setBit(data []byte, position int) {
// setBitAtFromMSB sets the bit at an offset from the most significant bit
func setBitAtFromMSB(data []byte, position int) {
n := int(data[position/8])
n |= (1 << (uint(position) % 8))
n |= (1 << (8 - 1 - uint(position)%8))
data[position/8] = byte(n)
}

func countSetBits(data []byte) int {
count := 0
for i := 0; i < len(data)*8; i++ {
if hasBit(data, i) == 1 {
if getBitAtFromMSB(data, i) == 1 {
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
count++
}
}
Expand All @@ -26,7 +28,7 @@ func countSetBits(data []byte) int {
func countCommonPrefix(data1 []byte, data2 []byte) int {
count := 0
for i := 0; i < len(data1)*8; i++ {
if hasBit(data1, i) == hasBit(data2, i) {
if getBitAtFromMSB(data1, i) == getBitAtFromMSB(data2, i) {
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
count++
} else {
break
Expand All @@ -39,3 +41,11 @@ func emptyBytes(length int) []byte {
b := make([]byte, length)
return b
}

func reverseSideNodes(sideNodes [][]byte) [][]byte {
for left, right := 0, len(sideNodes)-1; left < right; left, right = left+1, right-1 {
sideNodes[left], sideNodes[right] = sideNodes[right], sideNodes[left]
}

return sideNodes
}