diff --git a/core/trie2/collector.go b/core/trie2/collector.go index 710f55b5bb..d4bcca1aa0 100644 --- a/core/trie2/collector.go +++ b/core/trie2/collector.go @@ -91,7 +91,7 @@ func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2 var wg sync.WaitGroup var mu sync.Mutex - for i := 0; i < 2; i++ { + for i := range 2 { wg.Add(1) go func(idx int) { defer wg.Done() diff --git a/core/trie2/id.go b/core/trie2/id.go index 49e3b09fe6..ed9a8c517d 100644 --- a/core/trie2/id.go +++ b/core/trie2/id.go @@ -15,10 +15,8 @@ const ( // Represents the identifier for uniquely identifying a trie. type ID struct { - TrieType TrieType - Root felt.Felt // The root hash of the trie - Owner felt.Felt // The contract address which the trie belongs to - StorageRoot felt.Felt // The root hash of the storage trie of a contract. + TrieType TrieType + Owner felt.Felt // The contract address which the trie belongs to } // Returns the corresponding DB bucket for the trie @@ -39,31 +37,25 @@ func (id *ID) Bucket() db.Bucket { } // Constructs an identifier for a class trie with the provided class trie root hash -func ClassTrieID(root felt.Felt) *ID { +func ClassTrieID() *ID { return &ID{ - TrieType: ClassTrie, - Root: root, - Owner: felt.Zero, // class trie does not have an owner - StorageRoot: felt.Zero, // only contract storage trie has a storage root + TrieType: ClassTrie, + Owner: felt.Zero, // class trie does not have an owner } } // Constructs an identifier for a contract trie or a contract's storage trie -func ContractTrieID(root, owner, storageRoot felt.Felt) *ID { +func ContractTrieID(owner felt.Felt) *ID { return &ID{ - TrieType: ContractTrie, - Root: root, - Owner: owner, - StorageRoot: storageRoot, + TrieType: ContractTrie, + Owner: owner, } } // A general identifier, typically used for temporary trie -func TrieID(root felt.Felt) *ID { +func TrieID() *ID { return &ID{ - TrieType: Empty, - Root: root, - Owner: felt.Zero, - StorageRoot: felt.Zero, + TrieType: Empty, + Owner: felt.Zero, } } diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go index 3cc0229ecf..83d1e75626 100644 --- a/core/trie2/node_enc.go +++ b/core/trie2/node_enc.go @@ -102,31 +102,15 @@ func nodeToBytes(n node) []byte { } // Decodes the encoded bytes and returns the corresponding node -func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) { +func decodeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (node, error) { if len(blob) == 0 { return nil, errors.New("cannot decode empty blob") } - - var ( - n node - err error - isLeaf bool - ) - if pathLen > maxPathLen { - return nil, fmt.Errorf("node path length (%d) is greater than max path length (%d)", pathLen, maxPathLen) + return nil, fmt.Errorf("node path length (%d) > max (%d)", pathLen, maxPathLen) } - isLeaf = pathLen == maxPathLen - - if len(blob) == hashOrValueNodeSize { - var f felt.Felt - f.SetBytes(blob) - if isLeaf { - n = &valueNode{Felt: f} - } else { - n = &hashNode{Felt: f} - } + if n, ok := decodeHashOrValueNode(blob, pathLen, maxPathLen); ok { return n, nil } @@ -135,44 +119,66 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e switch nodeType { case binaryNodeType: - binary := &binaryNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash - binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen) - if err != nil { - return nil, err - } - binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen) - if err != nil { - return nil, err - } - - n = binary + return decodeBinaryNode(blob, hash, pathLen, maxPathLen) case edgeNodeType: - // Ensure the blob length is within the valid range for an edge node - if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { - return nil, fmt.Errorf("invalid node size: %d", len(blob)) - } + return decodeEdgeNode(blob, hash, pathLen, maxPathLen) + default: + panic(fmt.Sprintf("unknown node type: %d", nodeType)) + } +} - edge := &edgeNode{ - path: &trieutils.BitArray{}, - flags: nodeFlag{hash: &hashNode{Felt: hash}}, // cache the hash - } - edge.child, err = decodeNode(blob[:hashOrValueNodeSize], felt.Felt{}, pathLen, maxPathLen) - if err != nil { - return nil, err - } - if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil { - return nil, err +func decodeHashOrValueNode(blob []byte, pathLen, maxPathLen uint8) (node, bool) { + if len(blob) == hashOrValueNodeSize { + var f felt.Felt + f.SetBytes(blob) + if pathLen == maxPathLen { + return &valueNode{Felt: f}, true } + return &hashNode{Felt: f}, true + } + return nil, false +} - // We do another path length check to see if the node is a leaf - if pathLen+edge.path.Len() == maxPathLen { - edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt} - } +func decodeBinaryNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*binaryNode, error) { + if len(blob) < 2*hashOrValueNodeSize { + return nil, fmt.Errorf("invalid binary node size: %d", len(blob)) + } - n = edge - default: - panic(fmt.Sprintf("unknown decode node type: %d", nodeType)) + binary := &binaryNode{} + if hash != nil { + binary.flags.hash = &hashNode{Felt: *hash} + } + + var err error + if binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen); err != nil { + return nil, err + } + if binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen); err != nil { + return nil, err + } + return binary, nil +} + +func decodeEdgeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*edgeNode, error) { + if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { + return nil, fmt.Errorf("invalid edge node size: %d", len(blob)) + } + + edge := &edgeNode{path: &trieutils.BitArray{}} + if hash != nil { + edge.flags.hash = &hashNode{Felt: *hash} } - return n, nil + var err error + if edge.child, err = decodeNode(blob[:hashOrValueNodeSize], nil, pathLen, maxPathLen); err != nil { + return nil, err + } + if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil { + return nil, err + } + + if pathLen+edge.path.Len() == maxPathLen { + edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt} + } + return edge, nil } diff --git a/core/trie2/node_enc_test.go b/core/trie2/node_enc_test.go index 022e1730f9..a002606e87 100644 --- a/core/trie2/node_enc_test.go +++ b/core/trie2/node_enc_test.go @@ -40,7 +40,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(1, 0, 1), child: &valueNode{Felt: newFelt(123)}, - flags: nodeFlag{}, }, pathLen: 8, maxPath: 8, @@ -50,7 +49,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(0, 1, 0), child: &hashNode{Felt: newFelt(456)}, - flags: nodeFlag{}, }, pathLen: 3, maxPath: 8, @@ -62,7 +60,6 @@ func TestNodeEncodingDecoding(t *testing.T) { &hashNode{Felt: newFelt(111)}, &hashNode{Felt: newFelt(222)}, }, - flags: nodeFlag{}, }, pathLen: 0, maxPath: 8, @@ -74,7 +71,6 @@ func TestNodeEncodingDecoding(t *testing.T) { &valueNode{Felt: newFelt(555)}, &valueNode{Felt: newFelt(666)}, }, - flags: nodeFlag{}, }, pathLen: 7, maxPath: 8, @@ -96,7 +92,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(), child: &hashNode{Felt: newFelt(1111)}, - flags: nodeFlag{}, }, pathLen: 0, maxPath: 8, @@ -106,7 +101,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(1, 1, 1, 1, 1, 1, 1, 1), child: &valueNode{Felt: newFelt(2222)}, - flags: nodeFlag{}, }, pathLen: 0, maxPath: 8, @@ -120,7 +114,7 @@ func TestNodeEncodingDecoding(t *testing.T) { // Try to decode hash := tt.node.hash(crypto.Pedersen) - decoded, err := decodeNode(encoded, *hash, tt.pathLen, tt.maxPath) + decoded, err := decodeNode(encoded, hash, tt.pathLen, tt.maxPath) if tt.wantErr { require.Error(t, err) @@ -163,16 +157,14 @@ func TestNodeEncodingDecodingBoundary(t *testing.T) { // Test with invalid path lengths t.Run("invalid path lengths", func(t *testing.T) { - hash := felt.Zero blob := make([]byte, hashOrValueNodeSize) - _, err := decodeNode(blob, hash, 255, 8) // pathLen > maxPath + _, err := decodeNode(blob, nil, 255, 8) // pathLen > maxPath require.Error(t, err) }) // Test with empty buffer t.Run("empty buffer", func(t *testing.T) { - hash := felt.Zero - _, err := decodeNode([]byte{}, hash, 0, 8) + _, err := decodeNode([]byte{}, nil, 0, 8) require.Error(t, err) }) } diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 1916cb1339..49210cb210 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -152,11 +152,6 @@ func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) // - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements. // // The function returns a boolean indicating if there are more elements and an error if the range proof is invalid. -// -// TODO(weiihann): Given a binary leaf and a left-sibling first key, if the right sibling is removed, the proof would still be valid. -// Conversely, given a binary leaf and a right-sibling last key, if the left sibling is removed, the proof would still be valid. -// Range proof should not be valid for both of these cases, but currently is, which is an attack vector. -// The problem probably lies in how we do root hash calculation. func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo // Ensure the number of keys and values are the same if len(keys) != len(values) { @@ -164,7 +159,7 @@ func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, pro } // Ensure all keys are monotonically increasing and values contain no deletions - for i := 0; i < len(keys); i++ { + for i := range keys { if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 { return false, errors.New("keys are not monotonic increasing") } diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go index ed462e7992..69390425d7 100644 --- a/core/trie2/proof_test.go +++ b/core/trie2/proof_test.go @@ -372,7 +372,7 @@ func TestSingleSideRangeProof(t *testing.T) { keys := make([]*felt.Felt, i+1) values := make([]*felt.Felt, i+1) - for j := 0; j < i+1; j++ { + for j := range i + 1 { keys[j] = records[j].key values[j] = records[j].value } @@ -487,7 +487,7 @@ func TestBadRangeProof(t *testing.T) { tr, records := randomTrie(t, 5000) root := tr.Hash() - for i := 0; i < 500; i++ { + for range 500 { start := rand.Intn(len(records)) end := rand.Intn(len(records)-start) + start + 1 @@ -539,7 +539,7 @@ func TestBadRangeProof(t *testing.T) { func BenchmarkProve(b *testing.B) { tr, records := randomTrie(b, 1000) b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { proof := NewProofNodeSet() key := records[i%len(records)].key if err := tr.Prove(key, proof); err != nil { @@ -552,7 +552,7 @@ func BenchmarkVerifyProof(b *testing.B) { tr, records := randomTrie(b, 1000) root := tr.Hash() - var proofs []*ProofNodeSet + proofs := make([]*ProofNodeSet, 0, len(records)) for _, record := range records { proof := NewProofNodeSet() if err := tr.Prove(record.key, proof); err != nil { @@ -562,7 +562,7 @@ func BenchmarkVerifyProof(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { index := i % len(records) if _, err := VerifyProof(&root, records[index].key, proofs[index], crypto.Pedersen); err != nil { b.Fatal(err) @@ -589,7 +589,7 @@ func BenchmarkVerifyRangeProof(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _, err := VerifyRangeProof(&root, keys[0], keys, values, proof) require.NoError(b, err) } diff --git a/core/trie2/trie.go b/core/trie2/trie.go index d3830e9bc3..d00b16cd73 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -2,6 +2,7 @@ package trie2 import ( "bytes" + "errors" "fmt" "github.com/NethermindEth/juno/core/crypto" @@ -14,8 +15,6 @@ import ( const contractClassTrieHeight = 251 -var emptyRoot = felt.Felt{} - type Path = trieutils.BitArray type Trie struct { @@ -57,13 +56,13 @@ func New(id *ID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, nodeTracer: newTracer(), } - if id.Root != emptyRoot { - root, err := tr.resolveNode(&hashNode{Felt: id.Root}, Path{}) - if err != nil { - return nil, err - } + root, err := tr.resolveNode(nil, Path{}) + if err == nil { tr.root = root + } else if !errors.Is(err, db.ErrKeyNotFound) { + return nil, err } + return tr, nil } @@ -455,7 +454,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { } // Resolves the node at the given path from the database -func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { +func (t *Trie) resolveNode(hn *hashNode, path Path) (node, error) { buf := bufferPool.Get().(*bytes.Buffer) buf.Reset() defer func() { @@ -469,7 +468,12 @@ func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { } blob := buf.Bytes() - return decodeNode(blob, hash.Felt, path.Len(), t.height) + + var hash *felt.Felt + if hn != nil { + hash = &hn.Felt + } + return decodeNode(blob, hash, path.Len(), t.height) } // Calculate the hash of the root node @@ -499,9 +503,9 @@ func (t *Trie) String() string { } func NewEmptyPedersen() (*Trie, error) { - return New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) + return New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) } func NewEmptyPoseidon() (*Trie, error) { - return New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Poseidon, db.NewMemTransaction()) + return New(TrieID(), contractClassTrieHeight, crypto.Poseidon, db.NewMemTransaction()) } diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 7effe704a2..4b19a323ef 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -182,20 +182,11 @@ func TestHash(t *testing.T) { }) } -func TestMissingRoot(t *testing.T) { - var root felt.Felt - root.SetUint64(1) - - tr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) - require.Nil(t, tr) - require.Error(t, err) -} - func TestCommit(t *testing.T) { verifyCommit := func(t *testing.T, records []*keyValue) { t.Helper() db := db.NewMemTransaction() - tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, db) + tr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db) require.NoError(t, err) for _, record := range records { @@ -203,10 +194,10 @@ func TestCommit(t *testing.T) { require.NoError(t, err) } - root, err := tr.Commit() + _, err = tr.Commit() require.NoError(t, err) - tr2, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, db) + tr2, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db) require.NoError(t, err) for _, record := range records { @@ -361,7 +352,7 @@ func runRandTestBool(rt randTest) bool { //nolint:gocyclo func runRandTest(rt randTest) error { txn := db.NewMemTransaction() - tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, txn) + tr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, txn) if err != nil { return err } @@ -408,11 +399,11 @@ func runRandTest(rt randTest) error { case opHash: tr.Hash() case opCommit: - root, err := tr.Commit() + _, err := tr.Commit() if err != nil { rt[i].err = fmt.Errorf("commit failed: %w", err) } - newtr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, txn) + newtr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, txn) if err != nil { rt[i].err = fmt.Errorf("new trie failed: %w", err) } @@ -451,7 +442,7 @@ func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { tr, _ := NewEmptyPedersen() records := make([]*keyValue, n) - for i := 0; i < n; i++ { + for i := range n { key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) records[i] = &keyValue{key: key, value: key} err := tr.Update(key, key)