Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowosie committed Dec 9, 2024
1 parent e691455 commit bef9818
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 300 deletions.
9 changes: 4 additions & 5 deletions core/trie/key.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package trie

import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"

"github.com/NethermindEth/juno/core/felt"
Expand Down Expand Up @@ -61,7 +60,7 @@ func (k *Key) MostSignificantBits(n uint8) (*Key, error) {

func (k *Key) SubKey(n uint8) (*Key, error) {
if n > k.len {
return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len))
return nil, fmt.Errorf("cannot subtract key of length %d from key of length %d", n, k.len)
}
if n == k.len {
return &Key{}, nil
Expand Down Expand Up @@ -95,8 +94,8 @@ func (k *Key) unusedBytes() []byte {
return k.bitset[:len(k.bitset)-int(k.bytesNeeded())]
}

func (k *Key) WriteTo(buf *bytes.Buffer) (int64, error) {
if err := buf.WriteByte(k.len); err != nil {
func (k *Key) WriteTo(buf io.Writer) (int64, error) {
if _, err := buf.Write([]byte{k.len}); err != nil {
return 0, err
}

Expand Down
58 changes: 58 additions & 0 deletions core/trie/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package trie_test

import (
"bytes"
"errors"
"testing"

"github.com/NethermindEth/juno/core/felt"
Expand Down Expand Up @@ -153,3 +154,60 @@ func TestTruncate(t *testing.T) {
})
}
}

func TestKeyErrorHandling(t *testing.T) {
t.Run("passed too long key bytes panics", func(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
require.Contains(t, r.(string), "bytes does not fit in bitset")
}()
tooLongKeyB := make([]byte, 33)
trie.NewKey(8, tooLongKeyB)
})
t.Run("MostSignificantBits n greater than key length", func(t *testing.T) {
key := trie.NewKey(8, []byte{0x01})
_, err := key.MostSignificantBits(9)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot take 9 bits from key of length 8")
})
t.Run("MostSignificantBits equals key length return copy of key", func(t *testing.T) {
key := trie.NewKey(8, []byte{0x01})
kCopy, err := key.MostSignificantBits(8)
require.NoError(t, err)
require.Equal(t, key, *kCopy)
})
t.Run("SubKey n greater than key length", func(t *testing.T) {
key := trie.NewKey(8, []byte{0x01})
_, err := key.SubKey(9)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot subtract key of length 9 from key of length 8")
})
t.Run("SubKey n equals k length returns empty key", func(t *testing.T) {
key := trie.NewKey(8, []byte{0x01})
kCopy, err := key.SubKey(8)
require.NoError(t, err)
require.Equal(t, trie.Key{}, *kCopy)
})
t.Run("delete more bits than key length panics", func(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
require.Contains(t, r.(string), "deleting more bits than there are")
}()
key := trie.NewKey(8, []byte{0x01})
key.DeleteLSB(9)
})
t.Run("WriteTo returns error", func(t *testing.T) {
key := trie.NewKey(8, []byte{0x01})
wrote, err := key.WriteTo(&errorBuffer{})
require.Error(t, err)
require.Equal(t, int64(0), wrote)
})
}

type errorBuffer struct{}

func (*errorBuffer) Write([]byte) (int, error) {
return 0, errors.New("expected to fail")
}
4 changes: 2 additions & 2 deletions core/trie/node.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package trie

import (
"bytes"
"errors"
"io"

"github.com/NethermindEth/juno/core/felt"
)
Expand Down Expand Up @@ -38,7 +38,7 @@ func (n *Node) HashFromParent(parentKey, nodeKey *Key, hashFunc HashFunc) *felt.
return n.Hash(&path, hashFunc)
}

func (n *Node) WriteTo(buf *bytes.Buffer) (int64, error) {
func (n *Node) WriteTo(buf io.Writer) (int64, error) {
if n.Value == nil {
return 0, errors.New("cannot marshal node with nil value")
}
Expand Down
32 changes: 32 additions & 0 deletions core/trie/node_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package trie_test

import (
"bytes"
"encoding/hex"
"errors"
"testing"

"github.com/NethermindEth/juno/core/crypto"
Expand All @@ -26,3 +28,33 @@ func TestNodeHash(t *testing.T) {

assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed")
}

func TestNodeErrorHandling(t *testing.T) {
t.Run("WriteTo node value is nil", func(t *testing.T) {
node := trie.Node{}
var buffer bytes.Buffer
_, err := node.WriteTo(&buffer)
require.Error(t, err)
})
t.Run("WriteTo returns error", func(t *testing.T) {
node := trie.Node{
Value: new(felt.Felt).SetUint64(42),
Left: &trie.Key{},
Right: &trie.Key{},
}

wrote, err := node.WriteTo(&errorBuffer{})
require.Error(t, err)
require.Equal(t, int64(0), wrote)
})
t.Run("UnmarshalBinary returns error", func(t *testing.T) {
node := trie.Node{}

err := node.UnmarshalBinary([]byte{42})
require.Equal(t, errors.New("size of input data is less than felt size"), err)

bs := new(felt.Felt).Bytes()
err = node.UnmarshalBinary(append(bs[:], 0, 0, 42))
require.Equal(t, errors.New("the node does not contain both left and right hash"), err)
})
}
88 changes: 0 additions & 88 deletions core/trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,94 +274,6 @@ func VerifyProof(root *felt.Felt, key *Key, proofSet *ProofSet, hash HashFunc) (
}
}

// VerifyRangeProof verifies the range proof for the given range of keys.
// This is achieved by constructing a trie from the boundary proofs, and the supplied key-values.
// If the root of the reconstructed trie matches the supplied root, then the verification passes.
// If the trie is constructed incorrectly then the root will have an incorrect key(len,path), and value,
// and therefore its hash won't match the expected root.
// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484
//
//nolint:gocyclo
func VerifyRangeProof(root, firstKey *felt.Felt, keys, values []*felt.Felt, proofSet *ProofSet, hash HashFunc) (bool, error) {
// Ensure the number of keys and values are the same
if len(keys) != len(values) {
return false, fmt.Errorf("inconsistent proof data, number of keys: %d, number of values: %d", len(keys), len(values))
}

// Ensure all keys are monotonic increasing
for i := 0; i < len(keys)-1; i++ {
if keys[i].Cmp(keys[i+1]) >= 0 {
return false, errors.New("keys are not monotonic increasing")
}
}

// Ensure the range contains no deletions
for _, value := range values {
if value.Equal(&felt.Zero) {
return false, errors.New("range contains deletion")
}
}

// Special case: no edge proof at all, given range is the whole leaf set in the trie
if proofSet == nil {
tr, err := NewTriePedersen(newMemStorage(), 251) //nolint:mnd
if err != nil {
return false, err
}

for index, key := range keys {
_, err = tr.Put(key, values[index])
if err != nil {
return false, err
}
}

recomputedRoot, err := tr.Root()
if err != nil {
return false, err
}

if !recomputedRoot.Equal(root) {
return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String())
}

return true, nil
}

proofList := proofSet.List()
lastKey := keys[len(keys)-1]

// Construct the left proof path
leftProofPath, err := ProofToPath(proofList, &Key{len: 251, bitset: firstKey.Bytes()}, hash)
if err != nil {
return false, err
}

// Construct the right proof path
rightProofPath, err := ProofToPath(proofList, &Key{len: 251, bitset: lastKey.Bytes()}, hash)
if err != nil {
return false, err
}

// Build the trie from the proof paths and the key-value pairs
tr, err := BuildTrie(leftProofPath, rightProofPath, keys, values)
if err != nil {
return false, err
}

// Verify that the recomputed root hash matches the provided root hash
recomputedRoot, err := tr.Root()
if err != nil {
return false, err
}

if !recomputedRoot.Equal(root) {
return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String())
}

return true, nil
}

// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key
func compressNode(idx int, proofNodes []ProofNode, hashF HashFunc) (int, uint8, error) {
parent := proofNodes[idx]
Expand Down
Loading

0 comments on commit bef9818

Please sign in to comment.