Skip to content

Commit

Permalink
Fix encoding bug
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Feb 17, 2025
1 parent 5f183de commit 18b6b2a
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 122 deletions.
2 changes: 1 addition & 1 deletion core/trie2/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2
var wg sync.WaitGroup
var mu sync.Mutex

for i := 0; i < 2; i++ {
for i := range 2 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
Expand Down
30 changes: 11 additions & 19 deletions core/trie2/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ const (

// Represents the identifier for uniquely identifying a trie.
type ID struct {
TrieType TrieType
Root felt.Felt // The root hash of the trie
Owner felt.Felt // The contract address which the trie belongs to
StorageRoot felt.Felt // The root hash of the storage trie of a contract.
TrieType TrieType
Owner felt.Felt // The contract address which the trie belongs to
}

// Returns the corresponding DB bucket for the trie
Expand All @@ -39,31 +37,25 @@ func (id *ID) Bucket() db.Bucket {
}

// Constructs an identifier for a class trie with the provided class trie root hash
func ClassTrieID(root felt.Felt) *ID {
func ClassTrieID() *ID {
return &ID{
TrieType: ClassTrie,
Root: root,
Owner: felt.Zero, // class trie does not have an owner
StorageRoot: felt.Zero, // only contract storage trie has a storage root
TrieType: ClassTrie,
Owner: felt.Zero, // class trie does not have an owner
}

Check warning on line 44 in core/trie2/id.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/id.go#L40-L44

Added lines #L40 - L44 were not covered by tests
}

// Constructs an identifier for a contract trie or a contract's storage trie
func ContractTrieID(root, owner, storageRoot felt.Felt) *ID {
func ContractTrieID(owner felt.Felt) *ID {
return &ID{
TrieType: ContractTrie,
Root: root,
Owner: owner,
StorageRoot: storageRoot,
TrieType: ContractTrie,
Owner: owner,
}

Check warning on line 52 in core/trie2/id.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/id.go#L48-L52

Added lines #L48 - L52 were not covered by tests
}

// A general identifier, typically used for temporary trie
func TrieID(root felt.Felt) *ID {
func TrieID() *ID {
return &ID{
TrieType: Empty,
Root: root,
Owner: felt.Zero,
StorageRoot: felt.Zero,
TrieType: Empty,
Owner: felt.Zero,
}
}
110 changes: 58 additions & 52 deletions core/trie2/node_enc.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,31 +102,15 @@ func nodeToBytes(n node) []byte {
}

// Decodes the encoded bytes and returns the corresponding node
func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) {
func decodeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (node, error) {
if len(blob) == 0 {
return nil, errors.New("cannot decode empty blob")
}

var (
n node
err error
isLeaf bool
)

if pathLen > maxPathLen {
return nil, fmt.Errorf("node path length (%d) is greater than max path length (%d)", pathLen, maxPathLen)
return nil, fmt.Errorf("node path length (%d) > max (%d)", pathLen, maxPathLen)
}

isLeaf = pathLen == maxPathLen

if len(blob) == hashOrValueNodeSize {
var f felt.Felt
f.SetBytes(blob)
if isLeaf {
n = &valueNode{Felt: f}
} else {
n = &hashNode{Felt: f}
}
if n, ok := decodeHashOrValueNode(blob, pathLen, maxPathLen); ok {
return n, nil
}

Expand All @@ -135,44 +119,66 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e

switch nodeType {
case binaryNodeType:
binary := &binaryNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash
binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen)
if err != nil {
return nil, err
}
binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen)
if err != nil {
return nil, err
}

n = binary
return decodeBinaryNode(blob, hash, pathLen, maxPathLen)
case edgeNodeType:
// Ensure the blob length is within the valid range for an edge node
if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize {
return nil, fmt.Errorf("invalid node size: %d", len(blob))
}
return decodeEdgeNode(blob, hash, pathLen, maxPathLen)
default:
panic(fmt.Sprintf("unknown node type: %d", nodeType))

Check warning on line 126 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L125-L126

Added lines #L125 - L126 were not covered by tests
}
}

edge := &edgeNode{
path: &trieutils.BitArray{},
flags: nodeFlag{hash: &hashNode{Felt: hash}}, // cache the hash
}
edge.child, err = decodeNode(blob[:hashOrValueNodeSize], felt.Felt{}, pathLen, maxPathLen)
if err != nil {
return nil, err
}
if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil {
return nil, err
func decodeHashOrValueNode(blob []byte, pathLen, maxPathLen uint8) (node, bool) {
if len(blob) == hashOrValueNodeSize {
var f felt.Felt
f.SetBytes(blob)
if pathLen == maxPathLen {
return &valueNode{Felt: f}, true
}
return &hashNode{Felt: f}, true
}
return nil, false
}

// We do another path length check to see if the node is a leaf
if pathLen+edge.path.Len() == maxPathLen {
edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt}
}
func decodeBinaryNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*binaryNode, error) {
if len(blob) < 2*hashOrValueNodeSize {
return nil, fmt.Errorf("invalid binary node size: %d", len(blob))
}

Check warning on line 145 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L144-L145

Added lines #L144 - L145 were not covered by tests

n = edge
default:
panic(fmt.Sprintf("unknown decode node type: %d", nodeType))
binary := &binaryNode{}
if hash != nil {
binary.flags.hash = &hashNode{Felt: *hash}
}

var err error
if binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen); err != nil {
return nil, err
}

Check warning on line 155 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L154-L155

Added lines #L154 - L155 were not covered by tests
if binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen); err != nil {
return nil, err
}

Check warning on line 158 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L157-L158

Added lines #L157 - L158 were not covered by tests
return binary, nil
}

func decodeEdgeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*edgeNode, error) {
if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize {
return nil, fmt.Errorf("invalid edge node size: %d", len(blob))
}

Check warning on line 165 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L164-L165

Added lines #L164 - L165 were not covered by tests

edge := &edgeNode{path: &trieutils.BitArray{}}
if hash != nil {
edge.flags.hash = &hashNode{Felt: *hash}
}

return n, nil
var err error
if edge.child, err = decodeNode(blob[:hashOrValueNodeSize], nil, pathLen, maxPathLen); err != nil {
return nil, err
}

Check warning on line 175 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L174-L175

Added lines #L174 - L175 were not covered by tests
if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil {
return nil, err
}

Check warning on line 178 in core/trie2/node_enc.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/node_enc.go#L177-L178

Added lines #L177 - L178 were not covered by tests

if pathLen+edge.path.Len() == maxPathLen {
edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt}
}
return edge, nil
}
14 changes: 3 additions & 11 deletions core/trie2/node_enc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ func TestNodeEncodingDecoding(t *testing.T) {
node: &edgeNode{
path: newPath(1, 0, 1),
child: &valueNode{Felt: newFelt(123)},
flags: nodeFlag{},
},
pathLen: 8,
maxPath: 8,
Expand All @@ -50,7 +49,6 @@ func TestNodeEncodingDecoding(t *testing.T) {
node: &edgeNode{
path: newPath(0, 1, 0),
child: &hashNode{Felt: newFelt(456)},
flags: nodeFlag{},
},
pathLen: 3,
maxPath: 8,
Expand All @@ -62,7 +60,6 @@ func TestNodeEncodingDecoding(t *testing.T) {
&hashNode{Felt: newFelt(111)},
&hashNode{Felt: newFelt(222)},
},
flags: nodeFlag{},
},
pathLen: 0,
maxPath: 8,
Expand All @@ -74,7 +71,6 @@ func TestNodeEncodingDecoding(t *testing.T) {
&valueNode{Felt: newFelt(555)},
&valueNode{Felt: newFelt(666)},
},
flags: nodeFlag{},
},
pathLen: 7,
maxPath: 8,
Expand All @@ -96,7 +92,6 @@ func TestNodeEncodingDecoding(t *testing.T) {
node: &edgeNode{
path: newPath(),
child: &hashNode{Felt: newFelt(1111)},
flags: nodeFlag{},
},
pathLen: 0,
maxPath: 8,
Expand All @@ -106,7 +101,6 @@ func TestNodeEncodingDecoding(t *testing.T) {
node: &edgeNode{
path: newPath(1, 1, 1, 1, 1, 1, 1, 1),
child: &valueNode{Felt: newFelt(2222)},
flags: nodeFlag{},
},
pathLen: 0,
maxPath: 8,
Expand All @@ -120,7 +114,7 @@ func TestNodeEncodingDecoding(t *testing.T) {

// Try to decode
hash := tt.node.hash(crypto.Pedersen)
decoded, err := decodeNode(encoded, *hash, tt.pathLen, tt.maxPath)
decoded, err := decodeNode(encoded, hash, tt.pathLen, tt.maxPath)

if tt.wantErr {
require.Error(t, err)
Expand Down Expand Up @@ -163,16 +157,14 @@ func TestNodeEncodingDecodingBoundary(t *testing.T) {

// Test with invalid path lengths
t.Run("invalid path lengths", func(t *testing.T) {
hash := felt.Zero
blob := make([]byte, hashOrValueNodeSize)
_, err := decodeNode(blob, hash, 255, 8) // pathLen > maxPath
_, err := decodeNode(blob, nil, 255, 8) // pathLen > maxPath
require.Error(t, err)
})

// Test with empty buffer
t.Run("empty buffer", func(t *testing.T) {
hash := felt.Zero
_, err := decodeNode([]byte{}, hash, 0, 8)
_, err := decodeNode([]byte{}, nil, 0, 8)
require.Error(t, err)
})
}
7 changes: 1 addition & 6 deletions core/trie2/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,14 @@ func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn)
// - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements.
//
// The function returns a boolean indicating if there are more elements and an error if the range proof is invalid.
//
// TODO(weiihann): Given a binary leaf and a left-sibling first key, if the right sibling is removed, the proof would still be valid.
// Conversely, given a binary leaf and a right-sibling last key, if the left sibling is removed, the proof would still be valid.
// Range proof should not be valid for both of these cases, but currently is, which is an attack vector.
// The problem probably lies in how we do root hash calculation.
func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo
// Ensure the number of keys and values are the same
if len(keys) != len(values) {
return false, fmt.Errorf("inconsistent length of proof data, keys: %d, values: %d", len(keys), len(values))
}

Check warning on line 159 in core/trie2/proof.go

View check run for this annotation

Codecov / codecov/patch

core/trie2/proof.go#L158-L159

Added lines #L158 - L159 were not covered by tests

// Ensure all keys are monotonically increasing and values contain no deletions
for i := 0; i < len(keys); i++ {
for i := range keys {
if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 {
return false, errors.New("keys are not monotonic increasing")
}
Expand Down
12 changes: 6 additions & 6 deletions core/trie2/proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func TestSingleSideRangeProof(t *testing.T) {

keys := make([]*felt.Felt, i+1)
values := make([]*felt.Felt, i+1)
for j := 0; j < i+1; j++ {
for j := range i + 1 {
keys[j] = records[j].key
values[j] = records[j].value
}
Expand Down Expand Up @@ -487,7 +487,7 @@ func TestBadRangeProof(t *testing.T) {
tr, records := randomTrie(t, 5000)
root := tr.Hash()

for i := 0; i < 500; i++ {
for range 500 {
start := rand.Intn(len(records))
end := rand.Intn(len(records)-start) + start + 1

Expand Down Expand Up @@ -539,7 +539,7 @@ func TestBadRangeProof(t *testing.T) {
func BenchmarkProve(b *testing.B) {
tr, records := randomTrie(b, 1000)
b.ResetTimer()
for i := 0; i < b.N; i++ {
for i := range b.N {
proof := NewProofNodeSet()
key := records[i%len(records)].key
if err := tr.Prove(key, proof); err != nil {
Expand All @@ -552,7 +552,7 @@ func BenchmarkVerifyProof(b *testing.B) {
tr, records := randomTrie(b, 1000)
root := tr.Hash()

var proofs []*ProofNodeSet
proofs := make([]*ProofNodeSet, 0, len(records))
for _, record := range records {
proof := NewProofNodeSet()
if err := tr.Prove(record.key, proof); err != nil {
Expand All @@ -562,7 +562,7 @@ func BenchmarkVerifyProof(b *testing.B) {
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
for i := range b.N {
index := i % len(records)
if _, err := VerifyProof(&root, records[index].key, proofs[index], crypto.Pedersen); err != nil {
b.Fatal(err)
Expand All @@ -589,7 +589,7 @@ func BenchmarkVerifyRangeProof(b *testing.B) {
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
for range b.N {
_, err := VerifyRangeProof(&root, keys[0], keys, values, proof)
require.NoError(b, err)
}
Expand Down
Loading

0 comments on commit 18b6b2a

Please sign in to comment.