Skip to content

Commit

Permalink
Add read lock for mt.Root(); add verify tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jim Zhang <[email protected]>
  • Loading branch information
jimthematrix committed Jul 22, 2024
1 parent f062b35 commit 54a3778
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
9 changes: 7 additions & 2 deletions zkp/golang/internal/smt/merkletree.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func NewMerkleTree(db core.Storage, maxLevels int) (core.SparseMerkleTree, error
}

func (mt *sparseMerkleTree) Root() core.NodeIndex {
mt.RLock()
defer mt.RUnlock()
return mt.rootKey
}

Expand Down Expand Up @@ -103,6 +105,9 @@ func (mt *sparseMerkleTree) GetNode(key core.NodeIndex) (core.Node, error) {
// for a Merkle Tree given the root. It uses the node's index to represent the node.
// If the rootKey is nil, the current merkletree root is used
func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (core.Proof, *big.Int, error) {
mt.RLock()
defer mt.RUnlock()

p := &proof{}
var siblingKey core.NodeIndex

Expand All @@ -112,11 +117,11 @@ func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (c
}
path := kHash.ToPath(mt.maxLevels)
if rootKey == nil {
rootKey = mt.Root()
rootKey = mt.rootKey
}
nextKey := rootKey
for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ {
n, err := mt.GetNode(nextKey)
n, err := mt.getNode(nextKey)
if err != nil {
return nil, nil, err
}
Expand Down
48 changes: 48 additions & 0 deletions zkp/golang/internal/smt/smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
package smt

import (
"fmt"
"math/big"
"math/rand"
"testing"

"github.com/hyperledger-labs/zeto/internal/node"
"github.com/hyperledger-labs/zeto/internal/storage"
"github.com/hyperledger-labs/zeto/internal/testutils"
"github.com/hyperledger-labs/zeto/internal/utxo"
"github.com/hyperledger-labs/zeto/pkg/core"
"github.com/iden3/go-iden3-crypto/babyjub"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -144,3 +147,48 @@ func TestGenerateProof(t *testing.T) {
assert.NoError(t, err)
assert.False(t, proof3.IsOld0)
}

func TestVerifyProof(t *testing.T) {
const levels = 10
db := storage.NewMemoryStorage()
mt, _ := NewMerkleTree(db, levels)

alice := testutils.NewKeypair()
values := []int{10, 20, 30, 40, 50}
done := make(chan bool, len(values))
startProving := make(chan core.Node, len(values))
for idx, value := range values {
go func(v int) {
salt := rand.Intn(100000)
utxo := utxo.NewFungible(big.NewInt(int64(v)), alice.PublicKey, big.NewInt(int64(salt)))
node, err := node.NewLeafNode(utxo)
assert.NoError(t, err)
err = mt.AddLeaf(node)
assert.NoError(t, err)
startProving <- node
done <- true
fmt.Printf("Added node %d\n", idx)
}(value)
}

go func() {
// trigger the proving process after 1 nodes are added
n := <-startProving
fmt.Println("Received node for proving")

target := n.Index().BigInt()
root := mt.Root()
p, _, err := mt.GenerateProof(target, root)
assert.NoError(t, err)
assert.True(t, p.(*proof).existence)

valid := VerifyProof(root, p, n)
assert.True(t, valid)
}()

for i := 0; i < len(values); i++ {
<-done
}

fmt.Println("All done")
}

0 comments on commit 54a3778

Please sign in to comment.