diff --git a/zkp/golang/internal/smt/merkletree.go b/zkp/golang/internal/smt/merkletree.go index 98bc4db..bdebec5 100644 --- a/zkp/golang/internal/smt/merkletree.go +++ b/zkp/golang/internal/smt/merkletree.go @@ -59,6 +59,8 @@ func NewMerkleTree(db core.Storage, maxLevels int) (core.SparseMerkleTree, error } func (mt *sparseMerkleTree) Root() core.NodeIndex { + mt.RLock() + defer mt.RUnlock() return mt.rootKey } @@ -103,6 +105,9 @@ func (mt *sparseMerkleTree) GetNode(key core.NodeIndex) (core.Node, error) { // for a Merkle Tree given the root. It uses the node's index to represent the node. // If the rootKey is nil, the current merkletree root is used func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (core.Proof, *big.Int, error) { + mt.RLock() + defer mt.RUnlock() + p := &proof{} var siblingKey core.NodeIndex @@ -112,11 +117,11 @@ func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (c } path := kHash.ToPath(mt.maxLevels) if rootKey == nil { - rootKey = mt.Root() + rootKey = mt.rootKey } nextKey := rootKey for p.depth = 0; p.depth < uint(mt.maxLevels); p.depth++ { - n, err := mt.GetNode(nextKey) + n, err := mt.getNode(nextKey) if err != nil { return nil, nil, err } diff --git a/zkp/golang/internal/smt/smt_test.go b/zkp/golang/internal/smt/smt_test.go index 42c844c..c5e9670 100644 --- a/zkp/golang/internal/smt/smt_test.go +++ b/zkp/golang/internal/smt/smt_test.go @@ -17,13 +17,16 @@ package smt import ( + "fmt" "math/big" + "math/rand" "testing" "github.com/hyperledger-labs/zeto/internal/node" "github.com/hyperledger-labs/zeto/internal/storage" "github.com/hyperledger-labs/zeto/internal/testutils" "github.com/hyperledger-labs/zeto/internal/utxo" + "github.com/hyperledger-labs/zeto/pkg/core" "github.com/iden3/go-iden3-crypto/babyjub" "github.com/stretchr/testify/assert" ) @@ -144,3 +147,48 @@ func TestGenerateProof(t *testing.T) { assert.NoError(t, err) assert.False(t, proof3.IsOld0) } + +func TestVerifyProof(t *testing.T) { + const levels = 10 + db := storage.NewMemoryStorage() + mt, _ := NewMerkleTree(db, levels) + + alice := testutils.NewKeypair() + values := []int{10, 20, 30, 40, 50} + done := make(chan bool, len(values)) + startProving := make(chan core.Node, len(values)) + for idx, value := range values { + go func(v int) { + salt := rand.Intn(100000) + utxo := utxo.NewFungible(big.NewInt(int64(v)), alice.PublicKey, big.NewInt(int64(salt))) + node, err := node.NewLeafNode(utxo) + assert.NoError(t, err) + err = mt.AddLeaf(node) + assert.NoError(t, err) + startProving <- node + done <- true + fmt.Printf("Added node %d\n", idx) + }(value) + } + + go func() { + // trigger the proving process after 1 nodes are added + n := <-startProving + fmt.Println("Received node for proving") + + target := n.Index().BigInt() + root := mt.Root() + p, _, err := mt.GenerateProof(target, root) + assert.NoError(t, err) + assert.True(t, p.(*proof).existence) + + valid := VerifyProof(root, p, n) + assert.True(t, valid) + }() + + for i := 0; i < len(values); i++ { + <-done + } + + fmt.Println("All done") +}