diff --git a/go-sdk/integration-test/e2e_test.go b/go-sdk/integration-test/e2e_test.go index bfd25a9..2ebd972 100644 --- a/go-sdk/integration-test/e2e_test.go +++ b/go-sdk/integration-test/e2e_test.go @@ -284,13 +284,11 @@ func TestZeto_3_SuccessfulProving(t *testing.T) { assert.NoError(t, err) err = mt.AddLeaf(n2) assert.NoError(t, err) - proof1, _, err := mt.GenerateProof(input1, nil) + proofs, _, err := mt.GenerateProofs([]*big.Int{input1, input2}, nil) assert.NoError(t, err) - circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) + circomProof1, err := proofs[0].ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) assert.NoError(t, err) - proof2, _, err := mt.GenerateProof(input2, nil) - assert.NoError(t, err) - circomProof2, err := proof2.ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT) + circomProof2, err := proofs[1].ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT) assert.NoError(t, err) salt3 := crypto.NewSalt() @@ -370,13 +368,11 @@ func TestZeto_4_SuccessfulProving(t *testing.T) { assert.NoError(t, err) err = mt.AddLeaf(n2) assert.NoError(t, err) - proof1, _, err := mt.GenerateProof(input1, nil) - assert.NoError(t, err) - circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) + proofs, _, err := mt.GenerateProofs([]*big.Int{input1, input2}, nil) assert.NoError(t, err) - proof2, _, err := mt.GenerateProof(input2, nil) + circomProof1, err := proofs[0].ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) assert.NoError(t, err) - circomProof2, err := proof2.ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT) + circomProof2, err := proofs[1].ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT) assert.NoError(t, err) salt3 := crypto.NewSalt() @@ -510,9 +506,9 @@ func TestZeto_6_SuccessfulProving(t *testing.T) { assert.NoError(t, err) err = mt.AddLeaf(n1) assert.NoError(t, err) - proof1, _, err := mt.GenerateProof(input1, nil) + proofs, _, err := mt.GenerateProofs([]*big.Int{input1}, nil) assert.NoError(t, err) - circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) + circomProof1, err := proofs[0].ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT) assert.NoError(t, err) proof1Siblings := make([]*big.Int, len(circomProof1.Siblings)-1) for i, s := range circomProof1.Siblings[0 : len(circomProof1.Siblings)-1] { diff --git a/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go b/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go index e1f0a8b..4223117 100644 --- a/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go +++ b/go-sdk/internal/sparse-merkle-tree/smt/merkletree.go @@ -104,14 +104,29 @@ func (mt *sparseMerkleTree) GetNode(key core.NodeIndex) (core.Node, error) { // GenerateProof generates the proof of existence (or non-existence) of a leaf node // for a Merkle Tree given the root. It uses the node's index to represent the node. // If the rootKey is nil, the current merkletree root is used -func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (core.Proof, *big.Int, error) { +func (mt *sparseMerkleTree) GenerateProofs(keys []*big.Int, rootKey core.NodeIndex) ([]core.Proof, []*big.Int, error) { mt.RLock() defer mt.RUnlock() + merkleProofs := make([]core.Proof, len(keys)) + foundValues := make([]*big.Int, len(keys)) + for i, key := range keys { + proof, value, err := mt.generateProof(key, rootKey) + if err != nil { + return nil, nil, err + } + merkleProofs[i] = proof + foundValues[i] = value + } + + return merkleProofs, foundValues, nil +} + +func (mt *sparseMerkleTree) generateProof(key *big.Int, rootKey core.NodeIndex) (core.Proof, *big.Int, error) { p := &proof{} var siblingKey core.NodeIndex - kHash, err := node.NewNodeIndexFromBigInt(k) + kHash, err := node.NewNodeIndexFromBigInt(key) if err != nil { return nil, nil, err } @@ -160,7 +175,7 @@ func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (c p.siblings = append(p.siblings, siblingKey) } } - return nil, nil, ErrKeyNotFound + return nil, nil, ErrReachedMaxLevel } // must be called from inside a read lock diff --git a/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go b/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go index 11ab4d0..c2ab5b2 100644 --- a/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go +++ b/go-sdk/internal/sparse-merkle-tree/smt/smt_test.go @@ -131,22 +131,20 @@ func TestGenerateProof(t *testing.T) { assert.NoError(t, err) target1 := node1.Index().BigInt() - proof1, foundValue1, err := mt.GenerateProof(target1, mt.Root()) - assert.NoError(t, err) - assert.Equal(t, target1, foundValue1) - assert.True(t, proof1.(*proof).existence) - valid := VerifyProof(mt.Root(), proof1, node1) - assert.True(t, valid) utxo3 := node.NewFungible(big.NewInt(10), alice.PublicKey, big.NewInt(12347)) node3, err := node.NewLeafNode(utxo3) assert.NoError(t, err) target2 := node3.Index().BigInt() - proof2, _, err := mt.GenerateProof(target2, mt.Root()) + proofs, foundValues, err := mt.GenerateProofs([]*big.Int{target1, target2}, mt.Root()) assert.NoError(t, err) - assert.False(t, proof2.(*proof).existence) + assert.Equal(t, target1, foundValues[0]) + assert.True(t, proofs[0].(*proof).existence) + valid := VerifyProof(mt.Root(), proofs[0], node1) + assert.True(t, valid) + assert.False(t, proofs[1].(*proof).existence) - proof3, err := proof1.ToCircomVerifierProof(target1, foundValue1, mt.Root(), levels) + proof3, err := proofs[0].ToCircomVerifierProof(target1, foundValues[0], mt.Root(), levels) assert.NoError(t, err) assert.False(t, proof3.IsOld0) } @@ -181,11 +179,11 @@ func TestVerifyProof(t *testing.T) { target := n.Index().BigInt() root := mt.Root() - p, _, err := mt.GenerateProof(target, root) + p, _, err := mt.GenerateProofs([]*big.Int{target}, root) assert.NoError(t, err) - assert.True(t, p.(*proof).existence) + assert.True(t, p[0].(*proof).existence) - valid := VerifyProof(root, p, n) + valid := VerifyProof(root, p[0], n) assert.True(t, valid) }() diff --git a/go-sdk/pkg/sparse-merkle-tree/core/merkletree.go b/go-sdk/pkg/sparse-merkle-tree/core/merkletree.go index de5447a..316b110 100644 --- a/go-sdk/pkg/sparse-merkle-tree/core/merkletree.go +++ b/go-sdk/pkg/sparse-merkle-tree/core/merkletree.go @@ -36,9 +36,9 @@ type SparseMerkleTree interface { // Root returns the root hash of the tree Root() NodeIndex // AddLeaf adds a key-value pair to the tree - AddLeaf(Node) error + AddLeaf(leaf Node) error // GetNode returns the node at the given reference hash - GetNode(NodeIndex) (Node, error) + GetNode(node NodeIndex) (Node, error) // GetnerateProof generates a proof of existence (or non-existence) of a leaf node - GenerateProof(*big.Int, NodeIndex) (Proof, *big.Int, error) + GenerateProofs(nodeIndexes []*big.Int, root NodeIndex) ([]Proof, []*big.Int, error) }