From cebbeb93dc85cd2e5f4fa87ce1ddc1dc28531869 Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 13 Dec 2024 22:58:24 +0800 Subject: [PATCH 01/44] add bytes benchmark --- core/trie/bitarray_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index e711a9ddd6..6ad6bf0ad3 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -2089,3 +2089,40 @@ func TestSubset(t *testing.T) { }) } } + +func BenchmarkBitArrayBytes(b *testing.B) { + testCases := []struct { + name string + ba bitArray + }{ + { + name: "empty", + ba: bitArray{pos: 0, words: maxBitArray}, + }, + { + name: "pos_38", + ba: bitArray{pos: 38, words: maxBitArray}, + }, + { + name: "pos_100", + ba: bitArray{pos: 100, words: maxBitArray}, + }, + { + name: "pos_201", + ba: bitArray{pos: 201, words: maxBitArray}, + }, + { + name: "pos_255", + ba: bitArray{pos: 255, words: maxBitArray}, + }, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = tc.ba.Bytes() + } + }) + } +} From e53ac85a55952d2b596b01eb33dde5a1370105c8 Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 14 Dec 2024 18:40:04 +0800 Subject: [PATCH 02/44] add Rsh test --- core/trie/bitarray_test.go | 103 +++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 6ad6bf0ad3..c446bd38bd 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -2126,3 +2126,106 @@ func BenchmarkBitArrayBytes(b *testing.B) { }) } } + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *bitArray + shiftBy uint8 + expected *bitArray + }{ + { + name: "zero length array", + initial: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + shiftBy: 0, + expected: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 65, + expected: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 32, + expected: &bitArray{ + len: 96, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + shiftBy: 64, + expected: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + shiftBy: 128, + expected: &bitArray{ + len: 123, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + shiftBy: 192, + expected: &bitArray{ + len: 59, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + result := new(bitArray).Rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} From 4cc9075366693a885ef22747a2f395657384862e Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 14 Dec 2024 19:20:55 +0800 Subject: [PATCH 03/44] add Truncate --- core/trie/bitarray_test.go | 281 +++++++++++++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index c446bd38bd..35b0568fbe 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -2229,3 +2229,284 @@ func TestRsh(t *testing.T) { }) } } + +func TestPrefixEqual(t *testing.T) { + tests := []struct { + name string + a *bitArray + b *bitArray + want bool + }{ + { + name: "equal lengths, equal values", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: true, + }, + { + name: "equal lengths, different values", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different lengths, a longer but same prefix", + a: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, b longer but same prefix", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, different prefix", + a: &bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + b: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "zero length arrays", + a: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + b: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "one zero length array", + a: &bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + b: &bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "max length difference", + a: &bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + b: &bitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.a.PrefixEqual(tt.b); got != tt.want { + t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) + } + // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) + if got := tt.b.PrefixEqual(tt.a); got != tt.want { + t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + initial bitArray + length uint8 + expected bitArray + }{ + { + name: "truncate to zero", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 0, + expected: bitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "truncate within first word - 32 bits", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 
32, + expected: bitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "truncate to single bit", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 1, + expected: bitArray{ + len: 1, + words: [4]uint64{0x0000000000000001, 0, 0, 0}, + }, + }, + { + name: "truncate across words - 100 bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 100, + expected: bitArray{ + len: 100, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "truncate at word boundary - 64 bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + length: 64, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "truncate at word boundary - 128 bits", + initial: bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 128, + expected: bitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "truncate in third word - 150 bits", + initial: bitArray{ + len: 192, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, + }, + length: 150, + expected: bitArray{ + len: 150, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, + }, + }, + { + name: "truncate in fourth word - 220 bits", + initial: bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + length: 220, + expected: bitArray{ + len: 220, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, + }, + }, + { + name: "truncate max length - 251 bits", + initial: bitArray{ + len: 255, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, + }, + length: 251, + expected: bitArray{ + len: 251, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "truncate sparse bits", + initial: bitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: bitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 64, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + length: 128, + expected: bitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(bitArray).Truncate(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} From c4f75833ddbe01ba4d9f44fd4d5cc7eef556e50a Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 17 Dec 2024 23:48:36 +0800 Subject: [PATCH 04/44] all tests passed --- core/trie/node_test.go | 2 +- core/trie/proof_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/core/trie/node_test.go b/core/trie/node_test.go index 
cc1bb06eda..b222732f4b 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -24,5 +24,5 @@ func TestNodeHash(t *testing.T) { } path := trie.NewBitArray(6, 42) - assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") + assert.Equal(t, expected, node.Hash(path, crypto.Pedersen), "TestTrieNode_Hash failed") } diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 5a43932042..3780761551 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -13,6 +13,30 @@ import ( "github.com/stretchr/testify/require" ) +func TestFix(t *testing.T) { + numKeys := 1000 + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) + + records := make([]*keyValue, numKeys) + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) + require.NoError(t, err) + } + + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 + }) + + require.NoError(t, tempTrie.Commit()) +} + func TestProve(t *testing.T) { t.Parallel() From 97ea74b3c61fedfc90a9eebd27a67764f5c50184 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 10:35:09 +0800 Subject: [PATCH 05/44] improve comments --- core/trie/bitarray_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 35b0568fbe..5581346899 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "encoding/binary" + "math" "math/bits" "testing" @@ -11,6 +12,8 @@ import ( "github.com/stretchr/testify/require" ) +var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} + const ( ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 ) From 44ba4c508749fbd01ca57bc3e51dfd239b8e953d Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 18 Dec 2024 11:04:28 +0800 Subject: [PATCH 06/44] minor chore --- core/trie/proof_test.go | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 3780761551..5a43932042 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -13,30 +13,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestFix(t *testing.T) { - numKeys := 1000 - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) - - records := make([]*keyValue, numKeys) - for i := 1; i < numKeys+1; i++ { - key := new(felt.Felt).SetUint64(uint64(i)) - records[i-1] = &keyValue{key: key, value: key} - _, err := tempTrie.Put(key, key) - require.NoError(t, err) - } - - sort.Slice(records, func(i, j int) bool { - return records[i].key.Cmp(records[j].key) < 0 - }) - - require.NoError(t, tempTrie.Commit()) -} - func TestProve(t *testing.T) { t.Parallel() From dcec3a5f2931e49ca53582d963d8ec1e3f6a92b7 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 19 Dec 2024 12:53:03 +0800 Subject: [PATCH 07/44] improvements --- core/trie/node_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trie/node_test.go b/core/trie/node_test.go index b222732f4b..cc1bb06eda 100644 --- a/core/trie/node_test.go +++ b/core/trie/node_test.go @@ -24,5 +24,5 @@ func 
TestNodeHash(t *testing.T) { } path := trie.NewBitArray(6, 42) - assert.Equal(t, expected, node.Hash(path, crypto.Pedersen), "TestTrieNode_Hash failed") + assert.Equal(t, expected, node.Hash(&path, crypto.Pedersen), "TestTrieNode_Hash failed") } From 60fe040965a7d53d7dde05b90b48f5dcb14176b2 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 26 Dec 2024 11:51:54 +0800 Subject: [PATCH 08/44] Implement trie nodes use hashFn --- core/trie2/bitarray.go | 455 +++++++++++++++++++++++++++++++++++++++++ core/trie2/node.go | 142 +++++++++++++ 2 files changed, 597 insertions(+) create mode 100644 core/trie2/bitarray.go create mode 100644 core/trie2/node.go diff --git a/core/trie2/bitarray.go b/core/trie2/bitarray.go new file mode 100644 index 0000000000..62bacbdcb5 --- /dev/null +++ b/core/trie2/bitarray.go @@ -0,0 +1,455 @@ +package trie2 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "math/bits" + + "github.com/NethermindEth/juno/core/felt" +) + +const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + +var emptyBitArray = new(BitArray) + +// Represents a bit array with length representing the number of used bits. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. +type BitArray struct { + len uint8 // number of used bits + words [4]uint64 // little endian (i.e. words[0] is the least significant) +} + +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b +} + +// Returns the felt representation of the bit array. +func (b *BitArray) Felt() felt.Felt { + var f felt.Felt + f.SetBytes(b.Bytes()) + return f +} + +func (b *BitArray) Len() uint8 { + return b.len +} + +// Returns the bytes representation of the bit array in big endian format +func (b *BitArray) Bytes() []byte { + var res [32]byte + + b.truncateToLength() + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + + return res[:] +} + +// Sets the bit array to the least significant 'n' bits of x. +// If length >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// LSBs(x, 4) = 1011 (len=4) +// LSBs(x, 10) = 11001011 (len=8, original x) +// LSBs(x, 0) = 0 (len=0) +// +//nolint:mnd +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + b.Set(x) + b.len = n + + // Clear all words beyond what's needed + switch { + case n == 0: + b.words = [4]uint64{0, 0, 0, 0} + case n <= 64: + mask := maxUint64 >> (64 - n) + b.words[0] &= mask + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + case n <= 128: + mask := maxUint64 >> (128 - n) + b.words[1] &= mask + b.words[2] = 0 + b.words[3] = 0 + case n <= 192: + mask := maxUint64 >> (192 - n) + b.words[2] &= mask + b.words[3] = 0 + default: + mask := maxUint64 >> (256 - uint16(n)) + b.words[3] &= mask + } + + return b +} + +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. 
Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets the bit array to the most significant 'n' bits of x. +// If n >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + +// Sets the bit array to the longest sequence of matching most significant bits between two bit arrays. +// For example: +// +// x = 1101 0111 (len=8) +// y = 1101 0000 (len=8) +// CommonMSBs(x,y) = 1101 (len=4) +func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { + if x.len == 0 || y.len == 0 { + return b.clear() + } + + long, short := x, y + if x.len < y.len { + long, short = y, x + } + + // Align arrays by right-shifting longer array and then XOR to find differences + // Example: + // short = 1100 (len=4) + // long = 1101 0111 (len=8) + // + // Step 1: Right shift longer array by 4 + // short = 1100 + // long = 1101 + // + // Step 2: XOR shows difference at last bit + // 1100 (short) + // 1101 (aligned long) + // ---- XOR + // 0001 (difference at last position) + // We can then use the position of the first set bit and right-shift to get the common MSBs + diff := long.len - short.len + b.Rsh(long, diff).Xor(b, short) + divergentBit := findFirstSetBit(b) + + return b.Rsh(short, divergentBit) +} + +// Sets the bit array to x >> n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 { + return b.Set(x) + } + + if n >= x.len { + return b.clear() + } + + switch { + case n == 0: + return b.Set(x) + case n >= 192: + b.rsh192(x) + b.len = x.len - n + n -= 192 + b.words[0] >>= n + case n >= 128: + b.rsh128(x) + b.len = x.len - n + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + case n >= 64: + b.rsh64(x) + b.len = x.len - n + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] >>= n + default: + b.Set(x) + b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + } + + return b +} + +// Sets the bit array to x ^ y and returns the bit array. 
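+// Only the words are XORed; the receiver's length is left unchanged (callers such
+// as CommonMSBs set it via a prior Rsh or Set). Illustrative example:
+//
+//	x = 1100, y = 1010
+//	Xor(x, y) = 0110 (word bits only; b.len keeps its previous value)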
+func (b *BitArray) Xor(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] ^ y.words[0] + b.words[1] = x.words[1] ^ y.words[1] + b.words[2] = x.words[2] ^ y.words[2] + b.words[3] = x.words[3] ^ y.words[3] + return b +} + +// Checks if two bit arrays are equal +func (b *BitArray) Equal(x *BitArray) bool { + // TODO(weiihann): this is really not a good thing to do... + if b == nil && x == nil { + return true + } else if b == nil || x == nil { + return false + } + + return b.len == x.len && + b.words[0] == x.words[0] && + b.words[1] == x.words[1] && + b.words[2] == x.words[2] && + b.words[3] == x.words[3] +} + +// Returns true if bit n-th is set, where n = 0 is LSB. +// The n must be <= 255. +func (b *BitArray) IsBitSet(n uint8) bool { + if n >= b.len { + return false + } + + return (b.words[n/64] & (1 << (n % 64))) != 0 +} + +// Serialises the BitArray into a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { + if err := buf.WriteByte(b.len); err != nil { + return 0, err + } + + n, err := buf.Write(b.activeBytes()) + return n + 1, err +} + +// Deserialises the BitArray from a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *BitArray) UnmarshalBinary(data []byte) { + b.len = data[0] + + var bs [32]byte + copy(bs[32-b.byteCount():], data[1:]) + b.setBytes32(bs[:]) +} + +// Sets the bit array to the same value as x. +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +// Sets the bit array to the bytes representation of a felt. +func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { + b.len = length + b.setFelt(f) + b.truncateToLength() + return b +} + +// Sets the bit array to the bytes representation of a felt with length 251. +func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { + b.len = 251 + b.setFelt(f) + b.truncateToLength() + return b +} + +// Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. +// If the data is larger than 32 bytes, only the first 32 bytes are used. +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + b.setBytes32(data) + b.len = length + b.truncateToLength() + return b +} + +// Sets the bit array to the uint64 representation of a bit array. +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { + b.words[0] = data + b.len = length + b.truncateToLength() + return b +} + +// Returns the length of the encoded bit array in bytes. +func (b *BitArray) EncodedLen() uint { + return b.byteCount() + 1 +} + +// Returns a deep copy of the bit array. +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +// Returns a string representation of the bit array. 
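+// The format is "(len) hex", where hex is the full 32-byte value in big-endian,
+// e.g. a 10-bit array holding 0x3FF prints as "(10) 00...03ff".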
+func (b *BitArray) String() string { + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) +} + +func (b *BitArray) setFelt(f *felt.Felt) { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) +} + +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) +} + +// Returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *BitArray) byteCount() uint { + const bits8 = 8 + // Cast to uint16 to avoid overflow + return (uint(b.len) + (bits8 - 1)) / uint(bits8) +} + +// Returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *BitArray) activeBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + +func (b *BitArray) rsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *BitArray) rsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *BitArray) rsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *BitArray) clear() *BitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} + +// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +//nolint:mnd +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + +// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 
0100 (len=251) +// findFirstSetBit() = 3 // third bit from right is set +func findFirstSetBit(b *BitArray) uint8 { + if b.len == 0 { + return 0 + } + + // Start from the most significant and move towards the least significant + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + // All bits are zero, no set bit found + return 0 +} diff --git a/core/trie2/node.go b/core/trie2/node.go new file mode 100644 index 0000000000..c1878433f7 --- /dev/null +++ b/core/trie2/node.go @@ -0,0 +1,142 @@ +package trie2 + +import ( + "bytes" + "fmt" + "strings" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" +) + +var ( + _ node = (*internalNode)(nil) + _ node = (*edgeNode)(nil) + _ node = (*hashNode)(nil) + _ node = (*valueNode)(nil) +) + +type node interface { + hash(crypto.HashFn) *felt.Felt // TODO(weiihann): return felt value instead of pointers + cache() (*hashNode, bool) + encode(*bytes.Buffer) error + String() string +} + +type ( + internalNode struct { + children [2]node // 0 = left, 1 = right + flags nodeFlag + } + edgeNode struct { + child node + path *BitArray + flags nodeFlag + } + hashNode struct{ *felt.Felt } + valueNode struct{ *felt.Felt } +) + +type nodeFlag struct { + hash *hashNode + dirty bool +} + +func (n *internalNode) hash(hf crypto.HashFn) *felt.Felt { + return hf(n.children[0].hash(hf), n.children[1].hash(hf)) +} + +func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { + var length [32]byte + length[31] = n.path.len + pathFelt := n.path.Felt() + lengthFelt := new(felt.Felt).SetBytes(length[:]) + return new(felt.Felt).Add(hf(n.child.hash(hf), &pathFelt), lengthFelt) +} + +func (n hashNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } +func (n valueNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } + +func (n *internalNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n hashNode) cache() (*hashNode, bool) { return nil, true } +func (n valueNode) cache() (*hashNode, bool) { return nil, true } + +func (n *internalNode) String() string { + return fmt.Sprintf("Internal[\n left: %s\n right: %s\n]", + indent(n.children[0].String()), + indent(n.children[1].String())) +} + +func (n *edgeNode) String() string { + return fmt.Sprintf("Edge{\n path: %s\n child: %s\n}", + n.path.String(), + indent(n.child.String())) +} + +func (n hashNode) String() string { + return fmt.Sprintf("Hash(%s)", n.Felt.String()) +} + +func (n valueNode) String() string { + return fmt.Sprintf("Value(%s)", n.Felt.String()) +} + +func (n *internalNode) encode(buf *bytes.Buffer) error { + if err := n.children[0].encode(buf); err != nil { + return err + } + + if err := n.children[1].encode(buf); err != nil { + return err + } + + return nil +} + +func (n *edgeNode) encode(buf *bytes.Buffer) error { + if _, err := n.path.Write(buf); err != nil { + return err + } + + if err := n.child.encode(buf); err != nil { + return err + } + + return nil +} + +func (n hashNode) encode(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +func (n valueNode) encode(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +func (n *edgeNode) PathMatches(key *BitArray) bool { + return n.path.EqualMSBs(key) +} + +func (n *edgeNode) CommonPath(key *BitArray) BitArray { + var 
commonPath BitArray + commonPath.CommonMSBs(n.path, key) + return commonPath +} + +// Helper function to indent each line of a string +func indent(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + lines[i] = " " + line + } + return strings.Join(lines, "\n") +} From d8014296e1a1782f20f05a20128dffa85336c545 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 30 Dec 2024 22:24:25 +0800 Subject: [PATCH 09/44] Update() works on TrieD d --- core/trie2/bitarray.go | 168 ++++++++++++++++++++++++++++-- core/trie2/errors.go | 5 + core/trie2/node.go | 51 +++++---- core/trie2/trie.go | 225 ++++++++++++++++++++++++++++++++++++++++ core/trie2/trie_test.go | 89 ++++++++++++++++ 5 files changed, 510 insertions(+), 28 deletions(-) create mode 100644 core/trie2/errors.go create mode 100644 core/trie2/trie.go create mode 100644 core/trie2/trie_test.go diff --git a/core/trie2/bitarray.go b/core/trie2/bitarray.go index 62bacbdcb5..2d408d550d 100644 --- a/core/trie2/bitarray.go +++ b/core/trie2/bitarray.go @@ -11,7 +11,10 @@ import ( "github.com/NethermindEth/juno/core/felt" ) -const maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) var emptyBitArray = new(BitArray) @@ -56,16 +59,17 @@ func (b *BitArray) Bytes() []byte { } // Sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. // If length >= x.len, the bit array is an exact copy of x. // For example: // // x = 11001011 (len=8) -// LSBs(x, 4) = 1011 (len=4) -// LSBs(x, 10) = 11001011 (len=8, original x) -// LSBs(x, 0) = 0 (len=0) +// LSBsFromLSB(x, 4) = 1011 (len=4) +// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) +// LSBsFromLSB(x, 0) = 0 (len=0) // //nolint:mnd -func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { +func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { if n >= x.len { return b.Set(x) } @@ -100,6 +104,25 @@ func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { return b } +// Returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// For example: +// +// x = 11001011 (len=8) +// LSBs(x, 1) = 1001011 (len=7) +// LSBs(x, 10) = 0 (len=0) +// LSBs(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n == 0 { + return b.Set(x) + } + + if n > x.Len() { + return b.clear() + } + + return b.LSBsFromLSB(x, x.Len()-n) +} + // Checks if the current bit array share the same most significant bits with another, where the length of // the check is determined by the shorter array. Returns true if either array has // length 0, or if the first min(b.len, x.len) MSBs are identical. @@ -231,6 +254,85 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { b.words[3] >>= n } + b.truncateToLength() + return b +} + +// Lsh sets the bit array to x << n and returns the bit array. 
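+// The length grows by n, saturating at 255 bits. For example:
+//
+//	x = 101 (len=3)
+//	Lsh(x, 2) = 10100 (len=5)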
+// +//nolint:mnd +func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { + b.Set(x) + + if x.len == 0 || n == 0 { + return b + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n == 0: + return b + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +// Sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + + // First copy x + b.Set(x) + + // Then shift left by y's length and OR with y + return b.Lsh(b, y.len).Or(b, y) +} + +// Sets the bit array to x | y and returns the bit array. +func (b *BitArray) Or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len return b } @@ -260,13 +362,49 @@ func (b *BitArray) Equal(x *BitArray) bool { } // Returns true if bit n-th is set, where n = 0 is LSB. -// The n must be <= 255. -func (b *BitArray) IsBitSet(n uint8) bool { +func (b *BitArray) IsBitSetFromLSB(n uint8) bool { + return b.BitSetFromLSB(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitSetFromLSB(n uint8) uint8 { if n >= b.len { - return false + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 +} + +func (b *BitArray) IsBitSet(n uint8) bool { + return b.BitSet(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. 
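+// For example:
+//
+//	x = 101 (len=3)
+//	BitSet(0) = 1 (MSB), BitSet(1) = 0, BitSet(3) = 0 (out of bounds)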
+func (b *BitArray) BitSet(n uint8) uint8 { + if n >= b.Len() { + return 0 } - return (b.words[n/64] & (1 << (n % 64))) != 0 + return b.BitSetFromLSB(b.Len() - n - 1) +} + +// Returns the bit value at the most significant bit +func (b *BitArray) MSB() uint8 { + return b.BitSet(0) +} + +func (b *BitArray) LSB() uint8 { + return b.BitSetFromLSB(0) +} + +func (b *BitArray) IsEmpty() bool { + return b.len == 0 } // Serialises the BitArray into a bytes buffer in the following format: @@ -405,6 +543,18 @@ func (b *BitArray) rsh192(x *BitArray) { b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] } +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + func (b *BitArray) clear() *BitArray { b.len = 0 b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 diff --git a/core/trie2/errors.go b/core/trie2/errors.go new file mode 100644 index 0000000000..306fe62059 --- /dev/null +++ b/core/trie2/errors.go @@ -0,0 +1,5 @@ +package trie2 + +import "errors" + +var ErrCommitted = errors.New("trie is committed") diff --git a/core/trie2/node.go b/core/trie2/node.go index c1878433f7..5e072f7aae 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -10,7 +10,7 @@ import ( ) var ( - _ node = (*internalNode)(nil) + _ node = (*binaryNode)(nil) _ node = (*edgeNode)(nil) _ node = (*hashNode)(nil) _ node = (*valueNode)(nil) @@ -19,12 +19,12 @@ var ( type node interface { hash(crypto.HashFn) *felt.Felt // TODO(weiihann): return felt value instead of pointers cache() (*hashNode, bool) - encode(*bytes.Buffer) error + write(*bytes.Buffer) error String() string } type ( - internalNode struct { + binaryNode struct { children [2]node // 0 = left, 1 = right flags nodeFlag } @@ -37,12 +37,21 @@ type ( valueNode struct{ *felt.Felt } ) +const ( + binaryNodeType byte = iota + edgeNodeType + hashNodeType + valueNodeType +) + type nodeFlag struct { hash *hashNode dirty bool } -func (n *internalNode) hash(hf crypto.HashFn) *felt.Felt { +func newFlag() nodeFlag { return nodeFlag{dirty: false} } + +func (n *binaryNode) hash(hf crypto.HashFn) *felt.Felt { return hf(n.children[0].hash(hf), n.children[1].hash(hf)) } @@ -57,13 +66,13 @@ func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { func (n hashNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } func (n valueNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } -func (n *internalNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } -func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } -func (n hashNode) cache() (*hashNode, bool) { return nil, true } -func (n valueNode) cache() (*hashNode, bool) { return nil, true } +func (n *binaryNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n hashNode) cache() (*hashNode, bool) { return nil, true } +func (n valueNode) cache() (*hashNode, bool) { return nil, true } -func (n *internalNode) String() string { - return fmt.Sprintf("Internal[\n left: %s\n right: %s\n]", +func (n *binaryNode) String() string { + return fmt.Sprintf("Binary[\n left: %s\n right: %s\n]", indent(n.children[0].String()), indent(n.children[1].String())) } @@ 
-82,31 +91,31 @@ func (n valueNode) String() string { return fmt.Sprintf("Value(%s)", n.Felt.String()) } -func (n *internalNode) encode(buf *bytes.Buffer) error { - if err := n.children[0].encode(buf); err != nil { +func (n *binaryNode) write(buf *bytes.Buffer) error { + if err := n.children[0].write(buf); err != nil { return err } - if err := n.children[1].encode(buf); err != nil { + if err := n.children[1].write(buf); err != nil { return err } return nil } -func (n *edgeNode) encode(buf *bytes.Buffer) error { +func (n *edgeNode) write(buf *bytes.Buffer) error { if _, err := n.path.Write(buf); err != nil { return err } - if err := n.child.encode(buf); err != nil { + if err := n.child.write(buf); err != nil { return err } return nil } -func (n hashNode) encode(buf *bytes.Buffer) error { +func (n hashNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err } @@ -114,7 +123,7 @@ func (n hashNode) encode(buf *bytes.Buffer) error { return nil } -func (n valueNode) encode(buf *bytes.Buffer) error { +func (n valueNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err } @@ -122,11 +131,15 @@ func (n valueNode) encode(buf *bytes.Buffer) error { return nil } -func (n *edgeNode) PathMatches(key *BitArray) bool { +// TODO(weiihann): check if we want to return a pointer or a value +func (n *binaryNode) copy() *binaryNode { cpy := *n; return &cpy } +func (n *edgeNode) copy() *edgeNode { cpy := *n; return &cpy } + +func (n *edgeNode) pathMatches(key *BitArray) bool { return n.path.EqualMSBs(key) } -func (n *edgeNode) CommonPath(key *BitArray) BitArray { +func (n *edgeNode) commonPath(key *BitArray) BitArray { var commonPath BitArray commonPath.CommonMSBs(n.path, key) return commonPath diff --git a/core/trie2/trie.go b/core/trie2/trie.go new file mode 100644 index 0000000000..2a60ced1fe --- /dev/null +++ b/core/trie2/trie.go @@ -0,0 +1,225 @@ +package trie2 + +import ( + "fmt" + + "github.com/NethermindEth/juno/core/felt" +) + +type Trie struct { + height uint8 + root node + reader interface{} // TODO(weiihann): implement reader + // committed bool +} + +// TODO(weiihann): implement this +func NewTrie(height uint8) *Trie { + return &Trie{height: height} +} + +func (t *Trie) Update(key, value *felt.Felt) error { + // if t.commited { + // return ErrCommitted + // } + return t.update(key, value) +} + +func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { + k := t.FeltToKey(key) + // TODO(weiihann): get the value directly from the reader + val, root, didResolve, err := t.get(t.root, &k) + // In Starknet, a non-existent key is mapped to felt.Zero + if val == nil { + val = &felt.Zero + } + if err == nil && didResolve { + t.root = root + } + return val, err +} + +func (t *Trie) Delete(key *felt.Felt) error { + panic("TODO(weiihann): implement me") +} + +// Traverses the trie recursively to find the value that corresponds to the key. 
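+// It returns the value (nil if the key is not present), the possibly-updated
+// subtrie root, whether any node along the path had to be resolved, and an error.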
+func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { + switch n := n.(type) { + case *edgeNode: + if !n.pathMatches(key) { + return nil, nil, false, nil + } + val, child, didResolve, err := t.get(n.child, key.LSBs(key, n.path.Len())) + if err == nil && didResolve { + n = n.copy() + n.child = child + } + return val, n, didResolve, err + case *binaryNode: + bit := key.MSB() + val, child, didResolve, err := t.get(n.children[bit], key.LSBs(key, 1)) + if err == nil && didResolve { + n = n.copy() + n.children[bit] = child + } + return val, n, didResolve, err + case hashNode: + panic("TODO(weiihann): implement me") + case valueNode: + return n.Felt, n, false, nil + case nil: + return nil, nil, false, nil + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (t *Trie) update(key, value *felt.Felt) error { + k := t.FeltToKey(key) + if value.IsZero() { + _, n, err := t.delete(t.root, &k) + if err != nil { + return err + } + t.root = n + } else { + _, n, err := t.insert(t.root, &k, valueNode{Felt: value}) + if err != nil { + return err + } + t.root = n + } + return nil +} + +func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { + // We reach the end of the key + if key.Len() == 0 { + if v, ok := n.(valueNode); ok { + return v.Equal(value.(valueNode).Felt), value, nil + } + return true, value, nil + } + + switch n := n.(type) { + case *edgeNode: + match := n.commonPath(key) + // If the whole key matches, just keep this edge node as it is and update the value + if match.Len() == n.path.Len() { + dirty, newNode, err := t.insert(n.child, key.LSBs(key, match.Len()), value) + if !dirty || err != nil { + return false, n, err + } + return true, &edgeNode{ + path: n.path, + child: newNode, + flags: newFlag(), + }, nil + } + // Otherwise branch out at the bit index where they differ + branch := &binaryNode{flags: newFlag()} + var err error + _, branch.children[n.path.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) + if err != nil { + return false, n, err + } + + _, branch.children[key.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) + if err != nil { + return false, n, err + } + + // Replace this edge node with the new binary node if it occurs at the current MSB + if match.IsEmpty() { + return true, branch, nil + } + + return true, &edgeNode{path: new(BitArray).MSBs(key, match.Len()), child: branch, flags: newFlag()}, nil + + case *binaryNode: + bit := key.MSB() + dirty, newNode, err := t.insert(n.children[bit], new(BitArray).LSBs(key, 1), value) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = newFlag() + n.children[bit] = newNode + return true, n, nil + case nil: + if key.IsEmpty() { + return true, value, nil + } + return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil + case hashNode: + panic("TODO(weiihann): implement me") + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { + switch n := n.(type) { + case *edgeNode: + match := n.commonPath(key) + // Mismatched, don't do anything + if match.Len() < n.path.Len() { + return false, n, nil + } + // If the whole key matches, just delete the edge node + if match.Len() == key.Len() { + return true, nil, nil + } + + // Otherwise, we need to delete the child node + dirty, child, err := t.delete(n.child, key.LSBs(key, match.Len())) + if !dirty || err != nil { + return false, n, err 
+ } + switch child := child.(type) { + case *edgeNode: + return true, &edgeNode{path: n.path, child: child.child, flags: newFlag()}, nil + default: + return true, &edgeNode{path: n.path, child: child, flags: newFlag()}, nil + } + case *binaryNode: + bit := key.MSB() + dirty, newNode, err := t.delete(n.children[bit], key.LSBs(key, 1)) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = newFlag() + n.children[bit] = newNode + + if newNode != nil { + return true, n, nil + } + + // TODO(weiihann): combine this binary node with the child + + return true, n, nil + case valueNode: + return true, nil, nil + case nil: + return false, nil, nil + case hashNode: + panic("TODO(weiihann): implement me") + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (t *Trie) String() string { + if t.root == nil { + return "" + } + return t.root.String() +} + +func (t *Trie) FeltToKey(f *felt.Felt) BitArray { + var key BitArray + key.SetFelt(t.height, f) + return key +} diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go new file mode 100644 index 0000000000..5f91ecc40e --- /dev/null +++ b/core/trie2/trie_test.go @@ -0,0 +1,89 @@ +package trie2 + +import ( + "math/rand" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/stretchr/testify/require" +) + +func TestUpdate(t *testing.T) { + trie := NewTrie(251) + + key := new(felt.Felt).SetUint64(1) + value := new(felt.Felt).SetUint64(2) + err := trie.Update(key, value) + require.NoError(t, err) + + got, err := trie.Get(key) + require.NoError(t, err) + require.Equal(t, value, got) +} + +func TestUpdateRandom(t *testing.T) { + tr, records := randomTrie(t, 1000) + + for _, record := range records { + got, err := tr.Get(record.key) + require.NoError(t, err) + + if !got.Equal(record.value) { + t.Fatalf("expected %s, got %s", record.value, got) + } + } +} + +func Test4KeysTrieD(t *testing.T) { + tr, _ := build4KeysTrieD(t) + t.Log(tr.String()) +} + +type keyValue struct { + key *felt.Felt + value *felt.Felt +} + +func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { + rrand := rand.New(rand.NewSource(3)) + + tr := NewTrie(251) + records := make([]*keyValue, n) + + for i := 0; i < n; i++ { + key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) + records[i] = &keyValue{key: key, value: key} + err := tr.Update(key, key) + require.NoError(t, err) + } + + return tr, records +} + +func build4KeysTrieD(t *testing.T) (*Trie, []*keyValue) { + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(4)}, + {key: new(felt.Felt).SetUint64(4), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(6), value: new(felt.Felt).SetUint64(6)}, + {key: new(felt.Felt).SetUint64(7), value: new(felt.Felt).SetUint64(7)}, + } + + return buildTrie(t, records), records +} + +func buildTrie(t *testing.T, records []*keyValue) *Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } + + tempTrie := NewTrie(251) + + for _, record := range records { + err := tempTrie.Update(record.key, record.value) + t.Log("--------------------------------") + t.Log(tempTrie.String()) + require.NoError(t, err) + } + + return tempTrie +} From 42f4785182e54c2555ddb5a5939e89da8c4264bb Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 31 Dec 2024 13:58:50 +0800 Subject: [PATCH 10/44] add docs --- core/trie2/bitarray.go | 26 +++++++++---- core/trie2/node.go | 1 + core/trie2/trie.go | 81 +++++++++++++++++++++++++++++------------ 
core/trie2/trie_test.go | 59 +++++++++++++++++++++++------- 4 files changed, 124 insertions(+), 43 deletions(-) diff --git a/core/trie2/bitarray.go b/core/trie2/bitarray.go index 2d408d550d..6fe218df9e 100644 --- a/core/trie2/bitarray.go +++ b/core/trie2/bitarray.go @@ -363,12 +363,12 @@ func (b *BitArray) Equal(x *BitArray) bool { // Returns true if bit n-th is set, where n = 0 is LSB. func (b *BitArray) IsBitSetFromLSB(n uint8) bool { - return b.BitSetFromLSB(n) == 1 + return b.BitFromLSB(n) == 1 } // Returns the bit value at position n, where n = 0 is LSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSetFromLSB(n uint8) uint8 { +func (b *BitArray) BitFromLSB(n uint8) uint8 { if n >= b.len { return 0 } @@ -381,26 +381,26 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 { } func (b *BitArray) IsBitSet(n uint8) bool { - return b.BitSet(n) == 1 + return b.Bit(n) == 1 } // Returns the bit value at position n, where n = 0 is MSB. // If n is out of bounds, returns 0. -func (b *BitArray) BitSet(n uint8) uint8 { +func (b *BitArray) Bit(n uint8) uint8 { if n >= b.Len() { return 0 } - return b.BitSetFromLSB(b.Len() - n - 1) + return b.BitFromLSB(b.Len() - n - 1) } // Returns the bit value at the most significant bit func (b *BitArray) MSB() uint8 { - return b.BitSet(0) + return b.Bit(0) } func (b *BitArray) LSB() uint8 { - return b.BitSetFromLSB(0) + return b.BitFromLSB(0) } func (b *BitArray) IsEmpty() bool { @@ -479,6 +479,18 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { return b } +// Sets the bit array to a single bit. +func (b *BitArray) SetBit(bit bool) *BitArray { + b.len = 1 + if bit { + b.words[0] = 1 + } else { + b.words[0] = 0 + } + b.truncateToLength() + return b +} + // Returns the length of the encoded bit array in bytes. func (b *BitArray) EncodedLen() uint { return b.byteCount() + 1 diff --git a/core/trie2/node.go b/core/trie2/node.go index 5e072f7aae..2f2277b36b 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -139,6 +139,7 @@ func (n *edgeNode) pathMatches(key *BitArray) bool { return n.path.EqualMSBs(key) } +// Returns the common bits between the current node and the given key, starting from the most significant bit func (n *edgeNode) commonPath(key *BitArray) BitArray { var commonPath BitArray commonPath.CommonMSBs(n.path, key) diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 2a60ced1fe..2e97770488 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -18,6 +18,8 @@ func NewTrie(height uint8) *Trie { return &Trie{height: height} } +// Modifies or inserts a key-value pair in the trie. +// If value is zero, the key is deleted from the trie. func (t *Trie) Update(key, value *felt.Felt) error { // if t.commited { // return ErrCommitted @@ -25,6 +27,9 @@ func (t *Trie) Update(key, value *felt.Felt) error { return t.update(key, value) } +// Retrieves the value associated with the given key. +// Returns felt.Zero if the key doesn't exist. +// May update the trie's internal structure if nodes need to be resolved. func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { k := t.FeltToKey(key) // TODO(weiihann): get the value directly from the reader @@ -39,11 +44,17 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { return val, err } +// Removes the given key from the trie. 
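+// Deleting a key that is not present leaves the trie unchanged; this is
+// equivalent to calling Update with a zero value.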
func (t *Trie) Delete(key *felt.Felt) error { - panic("TODO(weiihann): implement me") + k := t.FeltToKey(key) + _, n, err := t.delete(t.root, new(BitArray), &k) + if err != nil { + return err + } + t.root = n + return nil } -// Traverses the trie recursively to find the value that corresponds to the key. func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: @@ -75,10 +86,12 @@ func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { } } +// Modifies the trie by either inserting/updating a value or deleting a key. +// The operation is determined by whether the value is zero (delete) or non-zero (insert/update). func (t *Trie) update(key, value *felt.Felt) error { k := t.FeltToKey(key) if value.IsZero() { - _, n, err := t.delete(t.root, &k) + _, n, err := t.delete(t.root, new(BitArray), &k) if err != nil { return err } @@ -104,8 +117,8 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { switch n := n.(type) { case *edgeNode: - match := n.commonPath(key) - // If the whole key matches, just keep this edge node as it is and update the value + match := n.commonPath(key) // get the matching bits between the current node and the key + // If the match is the same as the path, just keep this edge node as it is and update the value if match.Len() == n.path.Len() { dirty, newNode, err := t.insert(n.child, key.LSBs(key, match.Len()), value) if !dirty || err != nil { @@ -117,15 +130,15 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { flags: newFlag(), }, nil } - // Otherwise branch out at the bit index where they differ + // Otherwise branch out at the bit position where they differ branch := &binaryNode{flags: newFlag()} var err error - _, branch.children[n.path.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) + _, branch.children[n.path.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) if err != nil { return false, n, err } - _, branch.children[key.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) + _, branch.children[key.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) if err != nil { return false, n, err } @@ -135,22 +148,27 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { return true, branch, nil } + // Otherwise, create a new edge node with the path being the common path and the branch as the child return true, &edgeNode{path: new(BitArray).MSBs(key, match.Len()), child: branch, flags: newFlag()}, nil case *binaryNode: + // Go to the child node based on the MSB of the key bit := key.MSB() dirty, newNode, err := t.insert(n.children[bit], new(BitArray).LSBs(key, 1), value) if !dirty || err != nil { return false, n, err } + // Replace the child node with the new node n = n.copy() n.flags = newFlag() n.children[bit] = newNode return true, n, nil case nil: + // We reach the end of the key, return the value node if key.IsEmpty() { return true, value, nil } + // Otherwise, return a new edge node with the path being the key and the value as the child return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil case hashNode: panic("TODO(weiihann): implement me") @@ -159,7 +177,7 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { } } -func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { +func (t *Trie) delete(n node, 
prefix, key *BitArray) (bool, node, error) { switch n := n.(type) { case *edgeNode: match := n.commonPath(key) @@ -167,25 +185,28 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { if match.Len() < n.path.Len() { return false, n, nil } - // If the whole key matches, just delete the edge node + // If the whole key matches, remove the entire edge node if match.Len() == key.Len() { return true, nil, nil } - // Otherwise, we need to delete the child node - dirty, child, err := t.delete(n.child, key.LSBs(key, match.Len())) + // Otherwise, key is longer than current node path, so we need to delete the child. + // Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie. + keyPrefix := new(BitArray).MSBs(key, n.path.Len()) + dirty, child, err := t.delete(n.child, new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len())) if !dirty || err != nil { return false, n, err } switch child := child.(type) { case *edgeNode: - return true, &edgeNode{path: n.path, child: child.child, flags: newFlag()}, nil + return true, &edgeNode{path: new(BitArray).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil default: - return true, &edgeNode{path: n.path, child: child, flags: newFlag()}, nil + return true, &edgeNode{path: new(BitArray).Set(n.path), child: child, flags: newFlag()}, nil } case *binaryNode: bit := key.MSB() - dirty, newNode, err := t.delete(n.children[bit], key.LSBs(key, 1)) + keyPrefix := new(BitArray).MSBs(key, 1) + dirty, newNode, err := t.delete(n.children[bit], new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, 1)) if !dirty || err != nil { return false, n, err } @@ -193,13 +214,25 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { n.flags = newFlag() n.children[bit] = newNode + // If the child node is not nil, that means we still have 2 children in this binary node if newNode != nil { return true, n, nil } - // TODO(weiihann): combine this binary node with the child + // Otherwise, we need to combine this binary node with the other child + other := bit ^ 1 + bitPrefix := new(BitArray).SetBit(other == 1) + if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child path + return true, &edgeNode{ + path: new(BitArray).Append(bitPrefix, cn.path), + child: cn.child, + flags: newFlag(), + }, nil + } - return true, n, nil + // other child is not an edge node, create a new edge node with the bit prefix as the path + // containing the other child as the child + return true, &edgeNode{path: bitPrefix, child: n.children[other], flags: newFlag()}, nil case valueNode: return true, nil, nil case nil: @@ -211,15 +244,17 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) { } } +// Converts a Felt value into a BitArray representation suitable for +// use as a trie key with the specified height. 
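+// For example, with height 251 the felt value 1 maps to a 251-bit key whose only
+// set bit is the least significant one.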
+func (t *Trie) FeltToKey(f *felt.Felt) BitArray { + var key BitArray + key.SetFelt(t.height, f) + return key +} + func (t *Trie) String() string { if t.root == nil { return "" } return t.root.String() } - -func (t *Trie) FeltToKey(f *felt.Felt) BitArray { - var key BitArray - key.SetFelt(t.height, f) - return key -} diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 5f91ecc40e..f3001b694b 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -9,16 +9,16 @@ import ( ) func TestUpdate(t *testing.T) { - trie := NewTrie(251) + tr, records := nonRandomTrie(t, 1000) - key := new(felt.Felt).SetUint64(1) - value := new(felt.Felt).SetUint64(2) - err := trie.Update(key, value) - require.NoError(t, err) + for _, record := range records { + err := tr.Update(record.key, record.value) + require.NoError(t, err) - got, err := trie.Get(key) - require.NoError(t, err) - require.Equal(t, value, got) + got, err := tr.Get(record.key) + require.NoError(t, err) + require.Equal(t, record.value, got) + } } func TestUpdateRandom(t *testing.T) { @@ -34,9 +34,30 @@ func TestUpdateRandom(t *testing.T) { } } -func Test4KeysTrieD(t *testing.T) { - tr, _ := build4KeysTrieD(t) - t.Log(tr.String()) +func TestDelete(t *testing.T) { + tr, records := nonRandomTrie(t, 10000) + + for _, record := range records { + err := tr.Delete(record.key) + require.NoError(t, err) + + got, err := tr.Get(record.key) + require.NoError(t, err) + require.Equal(t, got, &felt.Zero) + } +} + +func TestDeleteRandom(t *testing.T) { + tr, records := randomTrie(t, 10000) + + for i := len(records) - 1; i >= 0; i-- { + err := tr.Delete(records[i].key) + require.NoError(t, err) + + got, err := tr.Get(records[i].key) + require.NoError(t, err) + require.Equal(t, got, &felt.Zero) + } } type keyValue struct { @@ -44,6 +65,20 @@ type keyValue struct { value *felt.Felt } +func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { + tr := NewTrie(251) + records := make([]*keyValue, numKeys) + + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + err := tr.Update(key, key) + require.NoError(t, err) + } + + return tr, records +} + func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) @@ -80,8 +115,6 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie { for _, record := range records { err := tempTrie.Update(record.key, record.value) - t.Log("--------------------------------") - t.Log(tempTrie.String()) require.NoError(t, err) } From 63649cc5b1ec3471cf3128c9e034698198f0adfe Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 2 Jan 2025 11:14:51 +0800 Subject: [PATCH 11/44] Add hasher --- core/trie2/hasher.go | 85 +++++++++++++++++++++++++++++ core/trie2/trie.go | 18 +++++- core/trie2/trie_test.go | 118 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 core/trie2/hasher.go diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go new file mode 100644 index 0000000000..7152cd12ec --- /dev/null +++ b/core/trie2/hasher.go @@ -0,0 +1,85 @@ +package trie2 + +import ( + "fmt" + "sync" + + "github.com/NethermindEth/juno/core/crypto" +) + +// hasher handles node hashing for the trie. It supports both sequential and parallel +// hashing modes. 
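// A small illustrative sketch of the edge-node hash rule that the hasher below applies,
// following edgeNode.hash in this package: hash = H(child_hash, path_as_felt) + path_length,
// with Pedersen as the hash function used throughout the tests of this series. The concrete
// child, path and length values here are made up for illustration.

package main

import (
	"fmt"

	"github.com/NethermindEth/juno/core/crypto"
	"github.com/NethermindEth/juno/core/felt"
)

func main() {
	childHash := new(felt.Felt).SetUint64(0xabc) // pretend hash of the child node
	pathFelt := new(felt.Felt).SetUint64(0b101)  // the edge path embedded in a felt
	pathLen := new(felt.Felt).SetUint64(3)       // number of bits in the path

	edgeHash := new(felt.Felt).Add(crypto.Pedersen(childHash, pathFelt), pathLen)
	fmt.Println(edgeHash.String())
}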
+type hasher struct { + hashFn crypto.HashFn // The hash function to use + parallel bool // Whether to hash binary node children in parallel +} + +func newHasher(hash crypto.HashFn, parallel bool) hasher { + return hasher{ + hashFn: hash, + parallel: parallel, + } +} + +// hash computes the hash of a node and returns both the hash node and a cached +// version of the original node. If the node already has a cached hash, returns +// that instead of recomputing. +func (h *hasher) hash(n node) (node, node) { + if hash, _ := n.cache(); hash != nil { + return hash, n + } + + switch n := n.(type) { + case *edgeNode: + collapsed, cached := h.hashEdgeChild(n) + hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case *binaryNode: + collapsed, cached := h.hashBinaryChildren(n) + hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case valueNode, hashNode: + return n, n + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (h *hasher) hashEdgeChild(n *edgeNode) (collapsed, cached *edgeNode) { + collapsed, cached = n.copy(), n.copy() + + switch n.child.(type) { + case *edgeNode, *binaryNode: + collapsed.child, cached.child = h.hash(n.child) + } + + return collapsed, cached +} + +func (h *hasher) hashBinaryChildren(n *binaryNode) (collapsed, cached *binaryNode) { + collapsed, cached = n.copy(), n.copy() + + if h.parallel { // TODO(weiihann): double check this parallel strategy + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + }() + + go func() { + defer wg.Done() + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + }() + + wg.Wait() + } else { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } + + return collapsed, cached +} diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 2e97770488..cff79c3a3a 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -3,6 +3,7 @@ package trie2 import ( "fmt" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" ) @@ -11,11 +12,12 @@ type Trie struct { root node reader interface{} // TODO(weiihann): implement reader // committed bool + hashFn crypto.HashFn } // TODO(weiihann): implement this -func NewTrie(height uint8) *Trie { - return &Trie{height: height} +func NewTrie(height uint8, hashFn crypto.HashFn) *Trie { + return &Trie{height: height, hashFn: hashFn} } // Modifies or inserts a key-value pair in the trie. @@ -55,6 +57,13 @@ func (t *Trie) Delete(key *felt.Felt) error { return nil } +// Returns the root hash of the trie. Calling this method will also cache the hash of each node in the trie. +func (t *Trie) Hash() *felt.Felt { + hash, cached := t.hashRoot() + t.root = cached + return hash.(*hashNode).Felt +} + func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: @@ -244,6 +253,11 @@ func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { } } +func (t *Trie) hashRoot() (node, node) { + h := newHasher(t.hashFn, false) // TODO(weiihann): handle parallel hashing + return h.hash(t.root) +} + // Converts a Felt value into a BitArray representation suitable for // use as a trie key with the specified height. 
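// A short sketch of the key conversion documented above, with an arbitrary input value:
// every felt is embedded into a fixed-width key of the trie's height, so all keys share the
// same length regardless of their numeric value. (The method is renamed FeltToPath later in
// the series.)

package main

import (
	"fmt"

	"github.com/NethermindEth/juno/core/crypto"
	"github.com/NethermindEth/juno/core/felt"
	"github.com/NethermindEth/juno/core/trie2"
)

func main() {
	tr := trie2.NewTrie(251, crypto.Pedersen)
	k := tr.FeltToKey(new(felt.Felt).SetUint64(5))

	fmt.Println(k.Len())    // 251: keys always span the full trie height
	fmt.Println(k.String()) // "(251) " followed by the 32-byte big-endian value in hex
}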
func (t *Trie) FeltToKey(f *felt.Felt) BitArray { diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index f3001b694b..6361eba4fa 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -4,7 +4,9 @@ import ( "math/rand" "testing" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" ) @@ -60,13 +62,123 @@ func TestDeleteRandom(t *testing.T) { } } +// The expected hashes are taken from Pathfinder's tests +func TestHash(t *testing.T) { + t.Run("one leaf", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + hash := tr.Hash() + + expected := "0x2ab889bd35e684623df9b4ea4a4a1f6d9e0ef39b67c1293b8a89dd17e351330" + require.Equal(t, expected, hash.String(), "expected %s, got %s", expected, hash.String()) + }) + + t.Run("two leaves", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + err := tr.Update(new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + err = tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(3)) + require.NoError(t, err) + root := tr.Hash() + + expected := "0x79acdb7a3d78052114e21458e8c4aecb9d781ce79308193c61a2f3f84439f66" + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("three leaves", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(16), + new(felt.Felt).SetUint64(17), + new(felt.Felt).SetUint64(19), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(10), + new(felt.Felt).SetUint64(11), + new(felt.Felt).SetUint64(12), + } + + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x7e2184e9e1a651fd556b42b4ff10e44a71b1709f641e0203dc8bd2b528e5e81" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("double binary", func(t *testing.T) { + // (249,0,x3) + // | + // (0, 0, x3) + // / \ + // (0,0,x1) (1, 1, 5) + // / \ | + // (2) (3) (5) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + new(felt.Felt).SetUint64(1), + new(felt.Felt).SetUint64(3), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(2), + new(felt.Felt).SetUint64(3), + new(felt.Felt).SetUint64(5), + } + + tr := NewTrie(251, crypto.Pedersen) + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x6a316f09913454294c6b6751dea8449bc2e235fdc04b2ab0e1ac7fea25cc34f" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("binary root", func(t *testing.T) { + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), + } + + vals := []*felt.Felt{ + utils.HexToFelt(t, "0xcc"), + utils.HexToFelt(t, "0xdd"), + } + + tr := NewTrie(251, crypto.Pedersen) + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x542ced3b6aeef48339129a03e051693eff6a566d3a0a94035fa16ab610dc9e2" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) +} + type keyValue struct { key *felt.Felt 
value *felt.Felt } func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { - tr := NewTrie(251) + tr := NewTrie(251, crypto.Pedersen) records := make([]*keyValue, numKeys) for i := 1; i < numKeys+1; i++ { @@ -82,7 +194,7 @@ func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) - tr := NewTrie(251) + tr := NewTrie(251, crypto.Pedersen) records := make([]*keyValue, n) for i := 0; i < n; i++ { @@ -111,7 +223,7 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie { t.Fatal("records must have at least one element") } - tempTrie := NewTrie(251) + tempTrie := NewTrie(251, crypto.Pedersen) for _, record := range records { err := tempTrie.Update(record.key, record.value) From a391cbf6f764d4e4ae034fcc3be2bdcb5c608fd1 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 7 Jan 2025 13:27:46 +0800 Subject: [PATCH 12/44] add tracer --- core/trie2/tracer.go | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 core/trie2/tracer.go diff --git a/core/trie2/tracer.go b/core/trie2/tracer.go new file mode 100644 index 0000000000..f6c67f0491 --- /dev/null +++ b/core/trie2/tracer.go @@ -0,0 +1,45 @@ +package trie2 + +import ( + "maps" +) + +type tracer struct { + inserts map[BitArray]struct{} + deletes map[BitArray]struct{} +} + +func newTracer() *tracer { + return &tracer{ + inserts: make(map[BitArray]struct{}), + deletes: make(map[BitArray]struct{}), + } +} + +func (t *tracer) onInsert(key *BitArray) { + t.inserts[*key] = struct{}{} +} + +func (t *tracer) onDelete(key *BitArray) { + t.deletes[*key] = struct{}{} +} + +func (t *tracer) reset() { + t.inserts = make(map[BitArray]struct{}) + t.deletes = make(map[BitArray]struct{}) +} + +func (t *tracer) copy() *tracer { + return &tracer{ + inserts: maps.Clone(t.inserts), + deletes: maps.Clone(t.deletes), + } +} + +func (t *tracer) deletedNodes() []BitArray { + keys := make([]BitArray, 0, len(t.deletes)) + for k := range t.deletes { + keys = append(keys, k) + } + return keys +} From 1b44cdc31ccba0b5cb5ee0b3cc20e817f438dd52 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 9 Jan 2025 15:36:13 +0800 Subject: [PATCH 13/44] Add package `trienode` --- core/trie2/trienode/node.go | 22 ++++ core/trie2/trienode/nodeset.go | 89 +++++++++++++ core/trie2/trienode/trienode_test.go | 186 +++++++++++++++++++++++++++ 3 files changed, 297 insertions(+) create mode 100644 core/trie2/trienode/node.go create mode 100644 core/trie2/trienode/nodeset.go create mode 100644 core/trie2/trienode/trienode_test.go diff --git a/core/trie2/trienode/node.go b/core/trie2/trienode/node.go new file mode 100644 index 0000000000..f601a65006 --- /dev/null +++ b/core/trie2/trienode/node.go @@ -0,0 +1,22 @@ +package trienode + +import ( + "github.com/NethermindEth/juno/core/felt" +) + +type Node struct { + blob []byte + hash felt.Felt +} + +func (r *Node) IsDeleted() bool { + return len(r.blob) == 0 +} + +func NewNode(hash felt.Felt, blob []byte) *Node { + return &Node{hash: hash, blob: blob} +} + +func NewDeleted() *Node { + return &Node{hash: felt.Felt{}, blob: nil} +} diff --git a/core/trie2/trienode/nodeset.go b/core/trie2/trienode/nodeset.go new file mode 100644 index 0000000000..c596ba6ee1 --- /dev/null +++ b/core/trie2/trienode/nodeset.go @@ -0,0 +1,89 @@ +package trienode + +import ( + "fmt" + "maps" + "sort" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/utils" +) + +type path 
= utils.BitArray + +type NodeSet struct { + Owner felt.Felt + Nodes map[path]*Node + updates int + deletes int +} + +func NewNodeSet(owner felt.Felt) *NodeSet { + return &NodeSet{Owner: owner, Nodes: make(map[path]*Node)} +} + +func (ns *NodeSet) Add(key path, node *Node) { + if node.IsDeleted() { + ns.deletes += 1 + } else { + ns.updates += 1 + } + ns.Nodes[key] = node +} + +func (ns *NodeSet) ForEach(desc bool, callback func(key path, node *Node)) { + paths := make([]path, 0, len(ns.Nodes)) + for key := range ns.Nodes { + paths = append(paths, key) + } + + if desc { // longest path first + sort.Slice(paths, func(i, j int) bool { + return paths[i].Cmp(&paths[j]) > 0 + }) + } else { + sort.Slice(paths, func(i, j int) bool { + return paths[i].Cmp(&paths[j]) < 0 + }) + } + + for _, key := range paths { + callback(key, ns.Nodes[key]) + } +} + +func (ns *NodeSet) MergeSet(other *NodeSet) error { + if ns.Owner != other.Owner { + return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, other.Owner) + } + maps.Copy(ns.Nodes, other.Nodes) + ns.updates += other.updates + ns.deletes += other.deletes + return nil +} + +func (ns *NodeSet) Merge(owner felt.Felt, other map[path]*Node) error { + if ns.Owner != owner { + return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, owner) + } + + for path, node := range other { + prev, ok := ns.Nodes[path] + if ok { // node already exists, revoke the counter first + if prev.IsDeleted() { + ns.deletes -= 1 + } else { + ns.updates -= 1 + } + } + // overwrite the existing node (if it exists) + if node.IsDeleted() { + ns.deletes += 1 + } else { + ns.updates += 1 + } + ns.Nodes[path] = node + } + + return nil +} diff --git a/core/trie2/trienode/trienode_test.go b/core/trie2/trienode/trienode_test.go new file mode 100644 index 0000000000..d757559fa4 --- /dev/null +++ b/core/trie2/trienode/trienode_test.go @@ -0,0 +1,186 @@ +package trienode + +import ( + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/utils" + "github.com/stretchr/testify/require" +) + +func TestNodeSet(t *testing.T) { + t.Run("new node set", func(t *testing.T) { + ns := NewNodeSet(felt.Zero) + require.Equal(t, felt.Zero, ns.Owner) + require.Empty(t, ns.Nodes) + require.Zero(t, ns.updates) + require.Zero(t, ns.deletes) + }) + + t.Run("add nodes", func(t *testing.T) { + ns := NewNodeSet(felt.Zero) + + // Add a regular node + key1 := utils.NewBitArray(8, 0xFF) + node1 := NewNode(felt.Zero, []byte{1, 2, 3}) + ns.Add(key1, node1) + require.Equal(t, 1, ns.updates) + require.Equal(t, 0, ns.deletes) + + // Add a deleted node + key2 := utils.NewBitArray(8, 0xAA) + node2 := NewDeleted() + ns.Add(key2, node2) + require.Equal(t, 1, ns.updates) + require.Equal(t, 1, ns.deletes) + + // Verify nodes are stored correctly + require.Equal(t, node1, ns.Nodes[key1]) + require.Equal(t, node2, ns.Nodes[key2]) + }) + + t.Run("merge sets", func(t *testing.T) { + ns1 := NewNodeSet(felt.Zero) + ns2 := NewNodeSet(felt.Zero) + + // Add nodes to first set + key1 := utils.NewBitArray(8, 0xFF) + node1 := NewNode(felt.Zero, []byte{1, 2, 3}) + ns1.Add(key1, node1) + + // Add nodes to second set + key2 := utils.NewBitArray(8, 0xAA) + node2 := NewDeleted() + ns2.Add(key2, node2) + + // Merge sets + err := ns1.MergeSet(ns2) + require.NoError(t, err) + + // Verify merged state + require.Equal(t, 2, len(ns1.Nodes)) + require.Equal(t, node1, ns1.Nodes[key1]) + require.Equal(t, node2, ns1.Nodes[key2]) + require.Equal(t, 1, ns1.updates) + 
require.Equal(t, 1, ns1.deletes) + }) + + t.Run("merge with different owners", func(t *testing.T) { + owner1 := new(felt.Felt).SetUint64(123) + owner2 := new(felt.Felt).SetUint64(456) + ns1 := NewNodeSet(*owner1) + ns2 := NewNodeSet(*owner2) + + err := ns1.MergeSet(ns2) + require.Error(t, err) + }) + + t.Run("merge map", func(t *testing.T) { + owner := new(felt.Felt).SetUint64(123) + ns := NewNodeSet(*owner) + + // Create a map to merge + nodes := make(map[utils.BitArray]*Node) + key1 := utils.NewBitArray(8, 0xFF) + node1 := NewNode(felt.Zero, []byte{1, 2, 3}) + nodes[key1] = node1 + + // Merge map + err := ns.Merge(*owner, nodes) + require.NoError(t, err) + + // Verify merged state + require.Equal(t, 1, len(ns.Nodes)) + require.Equal(t, node1, ns.Nodes[key1]) + require.Equal(t, 1, ns.updates) + require.Equal(t, 0, ns.deletes) + }) + + t.Run("foreach", func(t *testing.T) { + ns := NewNodeSet(felt.Zero) + + // Add nodes in random order + keys := []utils.BitArray{ + utils.NewBitArray(8, 0xFF), + utils.NewBitArray(8, 0xAA), + utils.NewBitArray(8, 0x55), + } + for _, key := range keys { + ns.Add(key, NewNode(felt.Zero, []byte{1})) + } + + t.Run("ascending order", func(t *testing.T) { + var visited []utils.BitArray + ns.ForEach(false, func(key utils.BitArray, node *Node) { + visited = append(visited, key) + }) + + // Verify ascending order + for i := 1; i < len(visited); i++ { + require.True(t, visited[i-1].Cmp(&visited[i]) < 0) + } + }) + + t.Run("descending order", func(t *testing.T) { + var visited []utils.BitArray + ns.ForEach(true, func(key utils.BitArray, node *Node) { + visited = append(visited, key) + }) + + // Verify descending order + for i := 1; i < len(visited); i++ { + require.True(t, visited[i-1].Cmp(&visited[i]) > 0) + } + }) + }) +} + +func TestNode(t *testing.T) { + t.Run("new node", func(t *testing.T) { + hash := new(felt.Felt).SetUint64(123) + blob := []byte{1, 2, 3} + node := NewNode(*hash, blob) + + require.Equal(t, *hash, node.hash) + require.Equal(t, blob, node.blob) + require.False(t, node.IsDeleted()) + }) + + t.Run("new deleted node", func(t *testing.T) { + node := NewDeleted() + require.True(t, node.IsDeleted()) + require.Equal(t, felt.Zero, node.hash) + require.Nil(t, node.blob) + }) + + t.Run("is deleted", func(t *testing.T) { + tests := []struct { + name string + blob []byte + expected bool + }{ + { + name: "nil blob", + blob: nil, + expected: true, + }, + { + name: "empty blob", + blob: []byte{}, + expected: true, + }, + { + name: "non-empty blob", + blob: []byte{1, 2, 3}, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + node := NewNode(felt.Zero, test.blob) + require.Equal(t, test.expected, node.IsDeleted()) + }) + } + }) +} From 47466e212b5fd0bd2f9cbca4354d5829d6b31387 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 9 Jan 2025 15:42:11 +0800 Subject: [PATCH 14/44] Add collector --- core/trie2/collector.go | 116 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 core/trie2/collector.go diff --git a/core/trie2/collector.go b/core/trie2/collector.go new file mode 100644 index 0000000000..aed7019595 --- /dev/null +++ b/core/trie2/collector.go @@ -0,0 +1,116 @@ +package trie2 + +import ( + "fmt" + "sync" +) + +import ( + "github.com/NethermindEth/juno/core/trie2/trienode" +) + +// Used as a tool to collect all dirty nodes in a NodeSet +type collector struct { + nodes *trienode.NodeSet +} + +func newCollector(nodes *trienode.NodeSet) *collector { + return 
&collector{nodes: nodes} +} + +// Collects the nodes in the node set and collapses a node into a hash node +func (c *collector) Collect(n node, parallel bool) hashNode { + return c.collect(new(Path), n, parallel).(hashNode) +} + +func (c *collector) collect(path *Path, n node, parallel bool) node { + // This path has not been modified, just return the cache + hash, dirty := n.cache() + if hash != nil && !dirty { + return hash + } + + // Collect children and then parent + switch cn := n.(type) { + case *edgeNode: + collapsed := cn.copy() + + // If the child is a binary node, recurse into it. + // Otherwise, it can only be a hashNode or valueNode. + // Combination of edge (parent) + edge (child) is not possible. + if _, ok := cn.child.(*binaryNode); ok { + collapsed.child = c.collect(new(Path).Append(path, cn.path), cn.child, parallel) + } + return c.store(path, collapsed) + case *binaryNode: + collapsed := cn.copy() + collapsed.children = c.collectChildren(path, cn, parallel) + return c.store(path, collapsed) + case hashNode: + return cn + case valueNode: // each leaf node is stored as a single entry in the node set + return c.store(path, cn) + default: + panic(fmt.Sprintf("unknown node type: %T", cn)) + } +} + +// Collects the children of a binary node, may apply parallel processing if configured +func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2]node { + children := [2]node{} + + // Helper function to process a single child + processChild := func(i int) { + child := n.children[i] + // Return early if it's already a hash node + if hn, ok := child.(*hashNode); ok { + children[i] = hn + return + } + + // Create child path + childPath := new(Path).Append(path, new(Path).SetBit(uint8(i))) + + if !parallel { + children[i] = c.collect(childPath, child, parallel) + return + } + + // Parallel processing + childSet := trienode.NewNodeSet(c.nodes.Owner) + childCollector := newCollector(childSet) + children[i] = childCollector.collect(childPath, child, parallel) + c.nodes.MergeSet(childSet) //nolint:errcheck // guaranteed to succeed because same owner + } + + if !parallel { + // Sequential processing + processChild(0) + processChild(1) + return children + } + + // Parallel processing + var wg sync.WaitGroup + var mu sync.Mutex + + for i := 0; i < 2; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + mu.Lock() + processChild(idx) + mu.Unlock() + }(i) + } + wg.Wait() + + return children +} + +// Stores the node in the node set and returns the hash node +func (c *collector) store(path *Path, n node) node { + hash, _ := n.cache() + c.nodes.Add(*path, trienode.NewNode(hash.Felt, nodeToBytes(n))) + return hash +} From 47c5e9b598c8f1e8cd170d2c391302e39ee0c0cd Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 9 Jan 2025 15:43:03 +0800 Subject: [PATCH 15/44] Add Trie.Commit --- core/trie2/hasher.go | 4 +- core/trie2/node.go | 58 ++--------- core/trie2/node_enc.go | 66 +++++++++++++ core/trie2/reader.go | 24 +++++ core/trie2/tracer.go | 32 +++--- core/trie2/trie.go | 152 +++++++++++++++++++++-------- core/trie2/trie_test.go | 8 ++ core/trie2/types.go | 8 ++ core/trie2/{ => utils}/bitarray.go | 56 +++++++++-- 9 files changed, 295 insertions(+), 113 deletions(-) create mode 100644 core/trie2/node_enc.go create mode 100644 core/trie2/reader.go create mode 100644 core/trie2/types.go rename core/trie2/{ => utils}/bitarray.go (92%) diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go index 7152cd12ec..c3f255dbc8 100644 --- a/core/trie2/hasher.go +++ b/core/trie2/hasher.go 
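// A sketch, with made-up keys and blobs, of how a NodeSet filled by the collector above can
// be drained. It uses the trienode API as of this patch: the ForEach callback has no error
// return yet, and paths are utils.BitArray (renamed to trieutils later in the series).
// Iterating with desc=true visits longest paths first, so deeper nodes can be handled before
// their parents.

package main

import (
	"fmt"

	"github.com/NethermindEth/juno/core/felt"
	"github.com/NethermindEth/juno/core/trie2/trienode"
	"github.com/NethermindEth/juno/core/trie2/utils"
)

func main() {
	ns := trienode.NewNodeSet(felt.Zero)
	ns.Add(utils.NewBitArray(8, 0xAA), trienode.NewNode(felt.Zero, []byte{0x01}))
	ns.Add(utils.NewBitArray(12, 0xABC), trienode.NewDeleted())

	ns.ForEach(true, func(key utils.BitArray, n *trienode.Node) {
		if n.IsDeleted() {
			fmt.Println("delete", key.String())
			return
		}
		fmt.Println("write", key.String())
	})
}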
@@ -32,12 +32,12 @@ func (h *hasher) hash(n node) (node, node) { switch n := n.(type) { case *edgeNode: collapsed, cached := h.hashEdgeChild(n) - hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + hn := &hashNode{Felt: *collapsed.hash(h.hashFn)} cached.flags.hash = hn return hn, cached case *binaryNode: collapsed, cached := h.hashBinaryChildren(n) - hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + hn := &hashNode{Felt: *collapsed.hash(h.hashFn)} cached.flags.hash = hn return hn, cached case valueNode, hashNode: diff --git a/core/trie2/node.go b/core/trie2/node.go index 2f2277b36b..d06b657b10 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -30,11 +30,11 @@ type ( } edgeNode struct { child node - path *BitArray + path *Path flags nodeFlag } - hashNode struct{ *felt.Felt } - valueNode struct{ *felt.Felt } + hashNode struct{ felt.Felt } + valueNode struct{ felt.Felt } ) const ( @@ -57,14 +57,14 @@ func (n *binaryNode) hash(hf crypto.HashFn) *felt.Felt { func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { var length [32]byte - length[31] = n.path.len + length[31] = n.path.Len() pathFelt := n.path.Felt() lengthFelt := new(felt.Felt).SetBytes(length[:]) return new(felt.Felt).Add(hf(n.child.hash(hf), &pathFelt), lengthFelt) } -func (n hashNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } -func (n valueNode) hash(crypto.HashFn) *felt.Felt { return n.Felt } +func (n hashNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } +func (n valueNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } func (n *binaryNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } @@ -91,57 +91,17 @@ func (n valueNode) String() string { return fmt.Sprintf("Value(%s)", n.Felt.String()) } -func (n *binaryNode) write(buf *bytes.Buffer) error { - if err := n.children[0].write(buf); err != nil { - return err - } - - if err := n.children[1].write(buf); err != nil { - return err - } - - return nil -} - -func (n *edgeNode) write(buf *bytes.Buffer) error { - if _, err := n.path.Write(buf); err != nil { - return err - } - - if err := n.child.write(buf); err != nil { - return err - } - - return nil -} - -func (n hashNode) write(buf *bytes.Buffer) error { - if _, err := buf.Write(n.Felt.Marshal()); err != nil { - return err - } - - return nil -} - -func (n valueNode) write(buf *bytes.Buffer) error { - if _, err := buf.Write(n.Felt.Marshal()); err != nil { - return err - } - - return nil -} - // TODO(weiihann): check if we want to return a pointer or a value func (n *binaryNode) copy() *binaryNode { cpy := *n; return &cpy } func (n *edgeNode) copy() *edgeNode { cpy := *n; return &cpy } -func (n *edgeNode) pathMatches(key *BitArray) bool { +func (n *edgeNode) pathMatches(key *Path) bool { return n.path.EqualMSBs(key) } // Returns the common bits between the current node and the given key, starting from the most significant bit -func (n *edgeNode) commonPath(key *BitArray) BitArray { - var commonPath BitArray +func (n *edgeNode) commonPath(key *Path) Path { + var commonPath Path commonPath.CommonMSBs(n.path, key) return commonPath } diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go new file mode 100644 index 0000000000..f3de40bce7 --- /dev/null +++ b/core/trie2/node_enc.go @@ -0,0 +1,66 @@ +package trie2 + +import ( + "bytes" + "sync" +) + +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +func (n *binaryNode) write(buf *bytes.Buffer) error { + if err := 
n.children[0].write(buf); err != nil { + return err + } + + if err := n.children[1].write(buf); err != nil { + return err + } + + return nil +} + +func (n *edgeNode) write(buf *bytes.Buffer) error { + if _, err := n.path.Write(buf); err != nil { + return err + } + + if err := n.child.write(buf); err != nil { + return err + } + + return nil +} + +func (n hashNode) write(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +func (n valueNode) write(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +func nodeToBytes(n node) []byte { + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + if err := n.write(buf); err != nil { + panic(err) + } + return buf.Bytes() +} diff --git a/core/trie2/reader.go b/core/trie2/reader.go new file mode 100644 index 0000000000..a804daa5be --- /dev/null +++ b/core/trie2/reader.go @@ -0,0 +1,24 @@ +package trie2 + +import ( + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" +) + +type NodeReader interface { + Node(owner *felt.Felt, path *Path, hash *felt.Felt) ([]byte, error) +} + +type nodeReader struct { + txn db.Transaction + trieType TrieType +} + +func (n *nodeReader) Node(owner *felt.Felt, path *Path, hash *felt.Felt) ([]byte, error) { + panic("implement me") +} + +type trieReader struct { + owner *felt.Felt + reader NodeReader +} diff --git a/core/trie2/tracer.go b/core/trie2/tracer.go index f6c67f0491..81b253718d 100644 --- a/core/trie2/tracer.go +++ b/core/trie2/tracer.go @@ -5,28 +5,36 @@ import ( ) type tracer struct { - inserts map[BitArray]struct{} - deletes map[BitArray]struct{} + inserts map[Path]struct{} + deletes map[Path]struct{} } func newTracer() *tracer { return &tracer{ - inserts: make(map[BitArray]struct{}), - deletes: make(map[BitArray]struct{}), + inserts: make(map[Path]struct{}), + deletes: make(map[Path]struct{}), } } -func (t *tracer) onInsert(key *BitArray) { - t.inserts[*key] = struct{}{} +func (t *tracer) onInsert(key *Path) { + k := *key + if _, present := t.deletes[k]; present { + return + } + t.inserts[k] = struct{}{} } -func (t *tracer) onDelete(key *BitArray) { - t.deletes[*key] = struct{}{} +func (t *tracer) onDelete(key *Path) { + k := *key + if _, present := t.inserts[k]; present { + return + } + t.deletes[k] = struct{}{} } func (t *tracer) reset() { - t.inserts = make(map[BitArray]struct{}) - t.deletes = make(map[BitArray]struct{}) + t.inserts = make(map[Path]struct{}) + t.deletes = make(map[Path]struct{}) } func (t *tracer) copy() *tracer { @@ -36,8 +44,8 @@ func (t *tracer) copy() *tracer { } } -func (t *tracer) deletedNodes() []BitArray { - keys := make([]BitArray, 0, len(t.deletes)) +func (t *tracer) deletedNodes() []Path { + keys := make([]Path, 0, len(t.deletes)) for k := range t.deletes { keys = append(keys, k) } diff --git a/core/trie2/trie.go b/core/trie2/trie.go index cff79c3a3a..b51384714e 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -5,19 +5,37 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/utils" + "github.com/NethermindEth/juno/db" ) +type Path = utils.BitArray + type Trie struct { + txn db.Transaction + owner felt.Felt height uint8 root node reader interface{} // TODO(weiihann): implement reader // committed bool hashFn 
crypto.HashFn + tracer *tracer + + // Tracks the number of leaves inserted since the last hashing operation + pendingHashes int + + // Tracks the total number of updates (inserts/deletes) since the last commit + pendingUpdates int } // TODO(weiihann): implement this func NewTrie(height uint8, hashFn crypto.HashFn) *Trie { - return &Trie{height: height, hashFn: hashFn} + return &Trie{ + height: height, + hashFn: hashFn, + tracer: newTracer(), + } } // Modifies or inserts a key-value pair in the trie. @@ -26,16 +44,21 @@ func (t *Trie) Update(key, value *felt.Felt) error { // if t.commited { // return ErrCommitted // } - return t.update(key, value) + if err := t.update(key, value); err != nil { + return err + } + t.pendingUpdates++ + t.pendingHashes++ + return nil } // Retrieves the value associated with the given key. // Returns felt.Zero if the key doesn't exist. // May update the trie's internal structure if nodes need to be resolved. func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { - k := t.FeltToKey(key) + k := t.FeltToPath(key) // TODO(weiihann): get the value directly from the reader - val, root, didResolve, err := t.get(t.root, &k) + val, root, didResolve, err := t.get(t.root, new(Path), &k) // In Starknet, a non-existent key is mapped to felt.Zero if val == nil { val = &felt.Zero @@ -48,29 +71,59 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { // Removes the given key from the trie. func (t *Trie) Delete(key *felt.Felt) error { - k := t.FeltToKey(key) - _, n, err := t.delete(t.root, new(BitArray), &k) + k := t.FeltToPath(key) + _, n, err := t.delete(t.root, new(Path), &k) if err != nil { return err } t.root = n + t.pendingUpdates++ + t.pendingHashes++ return nil } // Returns the root hash of the trie. Calling this method will also cache the hash of each node in the trie. 
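// A brief sketch of the Hash call documented above, reusing the single-leaf case from the
// Pathfinder-derived tests earlier in this series (key 1, value 2 under Pedersen). Only the
// program wrapper is new; the expected root comes from those tests.

package main

import (
	"fmt"

	"github.com/NethermindEth/juno/core/crypto"
	"github.com/NethermindEth/juno/core/felt"
	"github.com/NethermindEth/juno/core/trie2"
)

func main() {
	tr := trie2.NewTrie(251, crypto.Pedersen)
	if err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)); err != nil {
		panic(err)
	}

	root := tr.Hash()
	// Per the test suite: 0x2ab889bd35e684623df9b4ea4a4a1f6d9e0ef39b67c1293b8a89dd17e351330
	fmt.Println(root.String())
}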
-func (t *Trie) Hash() *felt.Felt { +func (t *Trie) Hash() felt.Felt { hash, cached := t.hashRoot() t.root = cached return hash.(*hashNode).Felt } -func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { +func (t *Trie) Commit() felt.Felt { + rootHash := t.Hash() + if hashedNode, dirty := t.root.cache(); !dirty { + t.root = hashedNode + return rootHash + } + + nodes := trienode.NewNodeSet(t.owner) + for _, Path := range t.tracer.deletedNodes() { + nodes.Add(Path, trienode.NewDeleted()) + } + + t.root = newCollector(nodes).Collect(t.root, t.pendingUpdates > 100) // TODO(weiihann): 100 is arbitrary + t.pendingUpdates = 0 + return rootHash +} + +func (t *Trie) Copy() *Trie { + return &Trie{ + txn: t.txn, + owner: t.owner, + height: t.height, + root: t.root, + hashFn: t.hashFn, + tracer: t.tracer.copy(), + } +} + +func (t *Trie) get(n node, prefix, key *Path) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: if !n.pathMatches(key) { return nil, nil, false, nil } - val, child, didResolve, err := t.get(n.child, key.LSBs(key, n.path.Len())) + val, child, didResolve, err := t.get(n.child, new(Path).Append(prefix, n.path), key.LSBs(key, n.path.Len())) if err == nil && didResolve { n = n.copy() n.child = child @@ -78,7 +131,7 @@ func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { return val, n, didResolve, err case *binaryNode: bit := key.MSB() - val, child, didResolve, err := t.get(n.children[bit], key.LSBs(key, 1)) + val, child, didResolve, err := t.get(n.children[bit], new(Path).SetBit(bit), key.LSBs(key, 1)) if err == nil && didResolve { n = n.copy() n.children[bit] = child @@ -87,7 +140,7 @@ func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { case hashNode: panic("TODO(weiihann): implement me") case valueNode: - return n.Felt, n, false, nil + return &n.Felt, n, false, nil case nil: return nil, nil, false, nil default: @@ -98,15 +151,15 @@ func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { // Modifies the trie by either inserting/updating a value or deleting a key. // The operation is determined by whether the value is zero (delete) or non-zero (insert/update). 
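// A tiny sketch, not part of the patch, of the zero-value convention described above:
// updating a key to zero routes through the same delete path as Delete, so both leave the
// key reading back as felt.Zero. The test name is hypothetical.

package trie2

import (
	"testing"

	"github.com/NethermindEth/juno/core/crypto"
	"github.com/NethermindEth/juno/core/felt"
	"github.com/stretchr/testify/require"
)

func TestZeroValueActsAsDeleteSketch(t *testing.T) {
	key := new(felt.Felt).SetUint64(7)

	a := NewTrie(251, crypto.Pedersen)
	require.NoError(t, a.Update(key, new(felt.Felt).SetUint64(9)))
	require.NoError(t, a.Update(key, &felt.Zero)) // zero value takes the delete branch

	b := NewTrie(251, crypto.Pedersen)
	require.NoError(t, b.Update(key, new(felt.Felt).SetUint64(9)))
	require.NoError(t, b.Delete(key))

	gotA, err := a.Get(key)
	require.NoError(t, err)
	gotB, err := b.Get(key)
	require.NoError(t, err)

	require.Equal(t, &felt.Zero, gotA)
	require.Equal(t, gotA, gotB)
}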
func (t *Trie) update(key, value *felt.Felt) error { - k := t.FeltToKey(key) + k := t.FeltToPath(key) if value.IsZero() { - _, n, err := t.delete(t.root, new(BitArray), &k) + _, n, err := t.delete(t.root, new(Path), &k) if err != nil { return err } t.root = n } else { - _, n, err := t.insert(t.root, &k, valueNode{Felt: value}) + _, n, err := t.insert(t.root, new(Path), &k, valueNode{Felt: *value}) if err != nil { return err } @@ -115,11 +168,12 @@ func (t *Trie) update(key, value *felt.Felt) error { return nil } -func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { +func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) { // We reach the end of the key if key.Len() == 0 { if v, ok := n.(valueNode); ok { - return v.Equal(value.(valueNode).Felt), value, nil + vFelt := value.(valueNode).Felt + return v.Equal(&vFelt), value, nil } return true, value, nil } @@ -127,9 +181,9 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { switch n := n.(type) { case *edgeNode: match := n.commonPath(key) // get the matching bits between the current node and the key - // If the match is the same as the path, just keep this edge node as it is and update the value + // If the match is the same as the Path, just keep this edge node as it is and update the value if match.Len() == n.path.Len() { - dirty, newNode, err := t.insert(n.child, key.LSBs(key, match.Len()), value) + dirty, newNode, err := t.insert(n.child, new(Path).Append(prefix, n.path), key.LSBs(key, match.Len()), value) if !dirty || err != nil { return false, n, err } @@ -142,12 +196,18 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { // Otherwise branch out at the bit position where they differ branch := &binaryNode{flags: newFlag()} var err error - _, branch.children[n.path.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child) + PathPrefix := new(Path).MSBs(n.path, match.Len()+1) + _, branch.children[n.path.Bit(match.Len())], err = t.insert( + nil, new(Path).Append(prefix, PathPrefix), new(Path).LSBs(n.path, match.Len()+1), n.child, + ) if err != nil { return false, n, err } - _, branch.children[key.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value) + keyPrefix := new(Path).MSBs(key, match.Len()+1) + _, branch.children[key.Bit(match.Len())], err = t.insert( + nil, new(Path).Append(prefix, keyPrefix), new(Path).LSBs(key, match.Len()+1), value, + ) if err != nil { return false, n, err } @@ -156,14 +216,18 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { if match.IsEmpty() { return true, branch, nil } + matchPrefix := new(Path).MSBs(key, match.Len()) + t.tracer.onInsert(new(Path).Append(prefix, matchPrefix)) - // Otherwise, create a new edge node with the path being the common path and the branch as the child - return true, &edgeNode{path: new(BitArray).MSBs(key, match.Len()), child: branch, flags: newFlag()}, nil + // Otherwise, create a new edge node with the Path being the common Path and the branch as the child + return true, &edgeNode{path: matchPrefix, child: branch, flags: newFlag()}, nil case *binaryNode: // Go to the child node based on the MSB of the key bit := key.MSB() - dirty, newNode, err := t.insert(n.children[bit], new(BitArray).LSBs(key, 1), value) + dirty, newNode, err := t.insert( + n.children[bit], new(Path).Append(prefix, new(Path).SetBit(bit)), new(Path).LSBs(key, 1), value, + ) if !dirty || err != nil { return 
false, n, err } @@ -173,11 +237,12 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { n.children[bit] = newNode return true, n, nil case nil: + t.tracer.onInsert(prefix) // We reach the end of the key, return the value node if key.IsEmpty() { return true, value, nil } - // Otherwise, return a new edge node with the path being the key and the value as the child + // Otherwise, return a new edge node with the Path being the key and the value as the child return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil case hashNode: panic("TODO(weiihann): implement me") @@ -186,7 +251,7 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) { } } -func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { +func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { switch n := n.(type) { case *edgeNode: match := n.commonPath(key) @@ -196,26 +261,28 @@ func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { } // If the whole key matches, remove the entire edge node if match.Len() == key.Len() { + t.tracer.onDelete(prefix) return true, nil, nil } - // Otherwise, key is longer than current node path, so we need to delete the child. + // Otherwise, key is longer than current node Path, so we need to delete the child. // Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie. - keyPrefix := new(BitArray).MSBs(key, n.path.Len()) - dirty, child, err := t.delete(n.child, new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len())) + keyPrefix := new(Path).MSBs(key, n.path.Len()) + dirty, child, err := t.delete(n.child, new(Path).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len())) if !dirty || err != nil { return false, n, err } switch child := child.(type) { case *edgeNode: - return true, &edgeNode{path: new(BitArray).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil + t.tracer.onDelete(new(Path).Append(prefix, n.path)) + return true, &edgeNode{path: new(Path).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil default: - return true, &edgeNode{path: new(BitArray).Set(n.path), child: child, flags: newFlag()}, nil + return true, &edgeNode{path: new(Path).Set(n.path), child: child, flags: newFlag()}, nil } case *binaryNode: bit := key.MSB() - keyPrefix := new(BitArray).MSBs(key, 1) - dirty, newNode, err := t.delete(n.children[bit], new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, 1)) + keyPrefix := new(Path).MSBs(key, 1) + dirty, newNode, err := t.delete(n.children[bit], new(Path).Append(prefix, keyPrefix), key.LSBs(key, 1)) if !dirty || err != nil { return false, n, err } @@ -230,16 +297,17 @@ func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { // Otherwise, we need to combine this binary node with the other child other := bit ^ 1 - bitPrefix := new(BitArray).SetBit(other == 1) - if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child path + bitPrefix := new(Path).SetBit(other) + if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child Path + t.tracer.onDelete(new(Path).Append(prefix, bitPrefix)) return true, &edgeNode{ - path: new(BitArray).Append(bitPrefix, cn.path), + path: new(Path).Append(bitPrefix, cn.path), child: cn.child, flags: newFlag(), }, nil } - // other child is not an edge node, create a new edge node with the bit prefix as the path + // 
other child is not an edge node, create a new edge node with the bit prefix as the Path // containing the other child as the child return true, &edgeNode{path: bitPrefix, child: n.children[other], flags: newFlag()}, nil case valueNode: @@ -254,14 +322,16 @@ func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { } func (t *Trie) hashRoot() (node, node) { - h := newHasher(t.hashFn, false) // TODO(weiihann): handle parallel hashing - return h.hash(t.root) + h := newHasher(t.hashFn, t.pendingHashes > 100) // TODO(weiihann): 100 is arbitrary + hashed, cached := h.hash(t.root) + t.pendingHashes = 0 + return hashed, cached } -// Converts a Felt value into a BitArray representation suitable for +// Converts a Felt value into a Path representation suitable to // use as a trie key with the specified height. -func (t *Trie) FeltToKey(f *felt.Felt) BitArray { - var key BitArray +func (t *Trie) FeltToPath(f *felt.Felt) Path { + var key Path key.SetFelt(t.height, f) return key } diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 6361eba4fa..edfbbe0a8e 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -172,6 +172,14 @@ func TestHash(t *testing.T) { }) } +func TestCommit(t *testing.T) { + tr, _ := nonRandomTrie(t, 1000) + tr2 := tr.Copy() + + root := tr.Commit() + require.Equal(t, root, tr2.Hash()) +} + type keyValue struct { key *felt.Felt value *felt.Felt diff --git a/core/trie2/types.go b/core/trie2/types.go new file mode 100644 index 0000000000..23ce5acc3f --- /dev/null +++ b/core/trie2/types.go @@ -0,0 +1,8 @@ +package trie2 + +type TrieType uint8 + +const ( + ClassTrie TrieType = iota + 1 + StateTrie +) diff --git a/core/trie2/bitarray.go b/core/trie2/utils/bitarray.go similarity index 92% rename from core/trie2/bitarray.go rename to core/trie2/utils/bitarray.go index 6fe218df9e..8a81bbe7d9 100644 --- a/core/trie2/bitarray.go +++ b/core/trie2/utils/bitarray.go @@ -1,4 +1,4 @@ -package trie2 +package utils import ( "bytes" @@ -158,7 +158,7 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) } -// Sets the bit array to the most significant 'n' bits of x. +// Sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). // If n >= x.len, the bit array is an exact copy of x. // For example: // @@ -480,14 +480,10 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { } // Sets the bit array to a single bit. -func (b *BitArray) SetBit(bit bool) *BitArray { +func (b *BitArray) SetBit(bit uint8) *BitArray { b.len = 1 - if bit { - b.words[0] = 1 - } else { - b.words[0] = 0 - } - b.truncateToLength() + b.words[0] = uint64(bit & 1) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 return b } @@ -503,7 +499,16 @@ func (b *BitArray) Copy() BitArray { return res } +// Returns the encoded string representation of the bit array. +func (b *BitArray) EncodedString() string { + var res []byte + res = append(res, b.len) + res = append(res, b.Bytes()...) + return string(res) +} + // Returns a string representation of the bit array. +// This is typically used for logging or debugging. func (b *BitArray) String() string { return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) } @@ -615,3 +620,36 @@ func findFirstSetBit(b *BitArray) uint8 { // All bits are zero, no set bit found return 0 } + +// Cmp compares two bit arrays lexicographically. +// The comparison is first done by length, then by content if lengths are equal. 
+// Returns: +// +// -1 if b < x +// 0 if b == x +// 1 if b > x +func (b *BitArray) Cmp(x *BitArray) int { + // First compare lengths + if b.len < x.len { + return -1 + } + if b.len > x.len { + return 1 + } + + // Lengths are equal, compare the actual bits + d0, carry := bits.Sub64(b.words[0], x.words[0], 0) + d1, carry := bits.Sub64(b.words[1], x.words[1], carry) + d2, carry := bits.Sub64(b.words[2], x.words[2], carry) + d3, carry := bits.Sub64(b.words[3], x.words[3], carry) + + if carry == 1 { + return -1 + } + + if d0|d1|d2|d3 == 0 { + return 0 + } + + return 1 +} From 0cf80f446c3fe33309411ad5c77ac8ad914b9758 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 15 Jan 2025 18:29:03 +0800 Subject: [PATCH 16/44] add Clear() to OrderedSet --- utils/orderedset.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/utils/orderedset.go b/utils/orderedset.go index 9d78809731..b5ab2a6680 100644 --- a/utils/orderedset.go +++ b/utils/orderedset.go @@ -75,3 +75,11 @@ func (o *OrderedSet[K, V]) Keys() []K { } return keys } + +func (o *OrderedSet[K, V]) Clear() { + o.lock.Lock() + defer o.lock.Unlock() + + o.items = nil + o.itemPos = make(map[K]int) +} From 1b5c7a5b4c51827051d916f851eb9b9a156f80bf Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 15 Jan 2025 18:29:11 +0800 Subject: [PATCH 17/44] implement triedb --- core/trie2/triedb/database.go | 100 ++++++++++++++++++++++++++++++++++ db/buckets.go | 3 + 2 files changed, 103 insertions(+) create mode 100644 core/trie2/triedb/database.go diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go new file mode 100644 index 0000000000..ac0b908e1b --- /dev/null +++ b/core/trie2/triedb/database.go @@ -0,0 +1,100 @@ +package triedb + +import ( + "bytes" + "sync" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" +) + +var dbBufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +type Database struct { + txn db.Transaction + prefix db.Bucket +} + +func New(txn db.Transaction, prefix db.Bucket) *Database { + return &Database{txn: txn, prefix: prefix} +} + +func (d *Database) Get(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) (int, error) { + dbBuf := dbBufferPool.Get().(*bytes.Buffer) + dbBuf.Reset() + defer func() { + dbBuf.Reset() + dbBufferPool.Put(dbBuf) + }() + + if err := d.dbKey(dbBuf, owner, path); err != nil { + return 0, err + } + + err := d.txn.Get(dbBuf.Bytes(), func(blob []byte) error { + buf.Write(blob) + return nil + }) + if err != nil { + return 0, err + } + + return buf.Len(), nil +} + +func (d *Database) Put(owner felt.Felt, path trieutils.BitArray, blob []byte) error { + buffer := dbBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer func() { + buffer.Reset() + dbBufferPool.Put(buffer) + }() + + if err := d.dbKey(buffer, owner, path); err != nil { + return err + } + + return d.txn.Set(buffer.Bytes(), blob) +} + +func (d *Database) Delete(owner felt.Felt, path trieutils.BitArray) error { + buffer := dbBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer func() { + buffer.Reset() + dbBufferPool.Put(buffer) + }() + + if err := d.dbKey(buffer, owner, path); err != nil { + return err + } + + return d.txn.Delete(buffer.Bytes()) +} + +func (d *Database) dbKey(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) error { + _, err := buf.Write(d.prefix.Key()) + if err != nil { + return err + } + + if owner != (felt.Felt{}) { + oBytes := owner.Bytes() + _, err = buf.Write(oBytes[:]) + 
if err != nil { + return err + } + } + + _, err = path.Write(buf) + if err != nil { + return err + } + + return nil +} diff --git a/db/buckets.go b/db/buckets.go index e5037378a3..fe6ea5a85c 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -38,6 +38,9 @@ const ( MempoolTail // key of the tail node MempoolLength // number of transactions MempoolNode + ClassTrie + ContractTrieContract + ContractTrieStorage ) // Key flattens a prefix and series of byte arrays into a single []byte. From b1c5ff221e6b261bf6bf6663955ec0199eb28d48 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 15 Jan 2025 18:29:35 +0800 Subject: [PATCH 18/44] rename trie/utils to trie/trieutils --- core/trie2/{utils => trieutils}/bitarray.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) rename core/trie2/{utils => trieutils}/bitarray.go (97%) diff --git a/core/trie2/utils/bitarray.go b/core/trie2/trieutils/bitarray.go similarity index 97% rename from core/trie2/utils/bitarray.go rename to core/trie2/trieutils/bitarray.go index 8a81bbe7d9..63fc134593 100644 --- a/core/trie2/utils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -1,4 +1,4 @@ -package utils +package trieutils import ( "bytes" @@ -12,6 +12,8 @@ import ( ) const ( + MaxBitArraySize = 33 // (1 + 4 * 8) bytes + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF maxUint8 = uint8(math.MaxUint8) ) @@ -37,7 +39,8 @@ func NewBitArray(length uint8, val uint64) BitArray { // Returns the felt representation of the bit array. func (b *BitArray) Felt() felt.Felt { var f felt.Felt - f.SetBytes(b.Bytes()) + bs := b.Bytes() + f.SetBytes(bs[:]) return f } @@ -46,7 +49,7 @@ func (b *BitArray) Len() uint8 { } // Returns the bytes representation of the bit array in big endian format -func (b *BitArray) Bytes() []byte { +func (b *BitArray) Bytes() [32]byte { var res [32]byte b.truncateToLength() @@ -55,7 +58,7 @@ func (b *BitArray) Bytes() []byte { binary.BigEndian.PutUint64(res[16:24], b.words[1]) binary.BigEndian.PutUint64(res[24:32], b.words[0]) - return res[:] + return res } // Sets the bit array to the least significant 'n' bits of x. @@ -326,6 +329,10 @@ func (b *BitArray) Append(x, y *BitArray) *BitArray { return b.Lsh(b, y.len).Or(b, y) } +func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { + return b.Append(x, new(BitArray).SetBit(bit)) +} + // Sets the bit array to x | y and returns the bit array. func (b *BitArray) Or(x, y *BitArray) *BitArray { b.words[0] = x.words[0] | y.words[0] @@ -502,15 +509,17 @@ func (b *BitArray) Copy() BitArray { // Returns the encoded string representation of the bit array. func (b *BitArray) EncodedString() string { var res []byte + bs := b.Bytes() res = append(res, b.len) - res = append(res, b.Bytes()...) + res = append(res, bs[:]...) return string(res) } // Returns a string representation of the bit array. // This is typically used for logging or debugging. 
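// A small sketch, with an arbitrary 8-bit path, contrasting the two string forms above:
// EncodedString is the compact form (one length byte followed by the 32 data bytes, i.e.
// MaxBitArraySize bytes in total), while String is the readable form intended for logs.

package main

import (
	"fmt"

	"github.com/NethermindEth/juno/core/trie2/trieutils"
)

func main() {
	k := trieutils.NewBitArray(8, 0xFF)

	enc := k.EncodedString()
	fmt.Println(len(enc))   // 33: 1 length byte + 32 bytes of big-endian data
	fmt.Println(k.String()) // "(8) " followed by 64 hex characters
}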
func (b *BitArray) String() string { - return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(b.Bytes())) + bs := b.Bytes() + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bs[:])) } func (b *BitArray) setFelt(f *felt.Felt) { From e945538af3dc24366072e5a5e5cf057daa59b4ba Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 15 Jan 2025 18:29:58 +0800 Subject: [PATCH 19/44] minor changes on trienode --- core/trie2/trienode/node.go | 4 ++++ core/trie2/trienode/nodeset.go | 22 ++++++++++--------- core/trie2/trienode/trienode_test.go | 32 +++++++++++++++------------- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/core/trie2/trienode/node.go b/core/trie2/trienode/node.go index f601a65006..febdbf53d1 100644 --- a/core/trie2/trienode/node.go +++ b/core/trie2/trienode/node.go @@ -13,6 +13,10 @@ func (r *Node) IsDeleted() bool { return len(r.blob) == 0 } +func (r *Node) Blob() []byte { + return r.blob +} + func NewNode(hash felt.Felt, blob []byte) *Node { return &Node{hash: hash, blob: blob} } diff --git a/core/trie2/trienode/nodeset.go b/core/trie2/trienode/nodeset.go index c596ba6ee1..ab891be3a0 100644 --- a/core/trie2/trienode/nodeset.go +++ b/core/trie2/trienode/nodeset.go @@ -6,23 +6,21 @@ import ( "sort" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie2/utils" + "github.com/NethermindEth/juno/core/trie2/trieutils" ) -type path = utils.BitArray - type NodeSet struct { Owner felt.Felt - Nodes map[path]*Node + Nodes map[trieutils.BitArray]*Node updates int deletes int } func NewNodeSet(owner felt.Felt) *NodeSet { - return &NodeSet{Owner: owner, Nodes: make(map[path]*Node)} + return &NodeSet{Owner: owner, Nodes: make(map[trieutils.BitArray]*Node)} } -func (ns *NodeSet) Add(key path, node *Node) { +func (ns *NodeSet) Add(key trieutils.BitArray, node *Node) { if node.IsDeleted() { ns.deletes += 1 } else { @@ -31,8 +29,8 @@ func (ns *NodeSet) Add(key path, node *Node) { ns.Nodes[key] = node } -func (ns *NodeSet) ForEach(desc bool, callback func(key path, node *Node)) { - paths := make([]path, 0, len(ns.Nodes)) +func (ns *NodeSet) ForEach(desc bool, callback func(key trieutils.BitArray, node *Node) error) error { + paths := make([]trieutils.BitArray, 0, len(ns.Nodes)) for key := range ns.Nodes { paths = append(paths, key) } @@ -48,8 +46,12 @@ func (ns *NodeSet) ForEach(desc bool, callback func(key path, node *Node)) { } for _, key := range paths { - callback(key, ns.Nodes[key]) + if err := callback(key, ns.Nodes[key]); err != nil { + return err + } } + + return nil } func (ns *NodeSet) MergeSet(other *NodeSet) error { @@ -62,7 +64,7 @@ func (ns *NodeSet) MergeSet(other *NodeSet) error { return nil } -func (ns *NodeSet) Merge(owner felt.Felt, other map[path]*Node) error { +func (ns *NodeSet) Merge(owner felt.Felt, other map[trieutils.BitArray]*Node) error { if ns.Owner != owner { return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, owner) } diff --git a/core/trie2/trienode/trienode_test.go b/core/trie2/trienode/trienode_test.go index d757559fa4..9b22e9e6e2 100644 --- a/core/trie2/trienode/trienode_test.go +++ b/core/trie2/trienode/trienode_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/core/trie2/utils" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/stretchr/testify/require" ) @@ -21,14 +21,14 @@ func TestNodeSet(t *testing.T) { ns := NewNodeSet(felt.Zero) // Add a regular node - key1 := utils.NewBitArray(8, 0xFF) + key1 := 
trieutils.NewBitArray(8, 0xFF) node1 := NewNode(felt.Zero, []byte{1, 2, 3}) ns.Add(key1, node1) require.Equal(t, 1, ns.updates) require.Equal(t, 0, ns.deletes) // Add a deleted node - key2 := utils.NewBitArray(8, 0xAA) + key2 := trieutils.NewBitArray(8, 0xAA) node2 := NewDeleted() ns.Add(key2, node2) require.Equal(t, 1, ns.updates) @@ -44,12 +44,12 @@ func TestNodeSet(t *testing.T) { ns2 := NewNodeSet(felt.Zero) // Add nodes to first set - key1 := utils.NewBitArray(8, 0xFF) + key1 := trieutils.NewBitArray(8, 0xFF) node1 := NewNode(felt.Zero, []byte{1, 2, 3}) ns1.Add(key1, node1) // Add nodes to second set - key2 := utils.NewBitArray(8, 0xAA) + key2 := trieutils.NewBitArray(8, 0xAA) node2 := NewDeleted() ns2.Add(key2, node2) @@ -80,8 +80,8 @@ func TestNodeSet(t *testing.T) { ns := NewNodeSet(*owner) // Create a map to merge - nodes := make(map[utils.BitArray]*Node) - key1 := utils.NewBitArray(8, 0xFF) + nodes := make(map[trieutils.BitArray]*Node) + key1 := trieutils.NewBitArray(8, 0xFF) node1 := NewNode(felt.Zero, []byte{1, 2, 3}) nodes[key1] = node1 @@ -100,19 +100,20 @@ func TestNodeSet(t *testing.T) { ns := NewNodeSet(felt.Zero) // Add nodes in random order - keys := []utils.BitArray{ - utils.NewBitArray(8, 0xFF), - utils.NewBitArray(8, 0xAA), - utils.NewBitArray(8, 0x55), + keys := []trieutils.BitArray{ + trieutils.NewBitArray(8, 0xFF), + trieutils.NewBitArray(8, 0xAA), + trieutils.NewBitArray(8, 0x55), } for _, key := range keys { ns.Add(key, NewNode(felt.Zero, []byte{1})) } t.Run("ascending order", func(t *testing.T) { - var visited []utils.BitArray - ns.ForEach(false, func(key utils.BitArray, node *Node) { + var visited []trieutils.BitArray + _ = ns.ForEach(false, func(key trieutils.BitArray, node *Node) error { visited = append(visited, key) + return nil }) // Verify ascending order @@ -122,9 +123,10 @@ func TestNodeSet(t *testing.T) { }) t.Run("descending order", func(t *testing.T) { - var visited []utils.BitArray - ns.ForEach(true, func(key utils.BitArray, node *Node) { + var visited []trieutils.BitArray + _ = ns.ForEach(true, func(key trieutils.BitArray, node *Node) error { visited = append(visited, key) + return nil }) // Verify descending order From 0262038927ab9bb563985d4b1b87c3bd4a127de0 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 15 Jan 2025 18:30:08 +0800 Subject: [PATCH 20/44] add trie id --- core/trie2/id.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 core/trie2/id.go diff --git a/core/trie2/id.go b/core/trie2/id.go new file mode 100644 index 0000000000..4150ea6a37 --- /dev/null +++ b/core/trie2/id.go @@ -0,0 +1,67 @@ +package trie2 + +import ( + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" +) + +type TrieType uint8 + +const ( + Empty TrieType = iota + ClassTrie + ContractTrie +) + +// Represents the identifier for uniquely identifying a trie. +type ID struct { + TrieType TrieType + Root felt.Felt // The root hash of the trie + Owner felt.Felt // The contract address which the trie belongs to + StorageRoot felt.Felt // The root hash of the storage trie of a contract. 
+} + +func (id *ID) Bucket() db.Bucket { + switch id.TrieType { + case ClassTrie: + return db.ClassTrie + case ContractTrie: + if id.Owner == (felt.Felt{}) { + return db.ContractTrieContract + } + return db.ContractTrieStorage + case Empty: + return db.Bucket(0) + default: + panic("invalid trie type") + } +} + +// Constructs an identifier for a class trie with the provided class trie root hash +func ClassTrieID(root felt.Felt) *ID { + return &ID{ + TrieType: ClassTrie, + Root: root, + Owner: felt.Zero, // class trie does not have an owner + StorageRoot: felt.Zero, // only contract storage trie has a storage root + } +} + +// Constructs an identifier for a contract trie or a contract's storage trie +func ContractTrieID(root, owner, storageRoot felt.Felt) *ID { + return &ID{ + TrieType: ContractTrie, + Root: root, + Owner: owner, + StorageRoot: storageRoot, + } +} + +func TrieID(root felt.Felt) *ID { + return &ID{ + TrieType: Empty, + Root: root, + Owner: felt.Zero, + StorageRoot: felt.Zero, + } +} From efe4bb295e916661777296c747e951ec7783796d Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 15 Jan 2025 18:32:44 +0800 Subject: [PATCH 21/44] ...a bunch of changes --- core/trie2/collector.go | 15 ++- core/trie2/hasher.go | 2 +- core/trie2/node.go | 17 +--- core/trie2/node_enc.go | 69 +++++++++++++- core/trie2/proof.go | 156 +++++++++++++++++++++++++++++++ core/trie2/proof_test.go | 1 + core/trie2/reader.go | 24 ----- core/trie2/trie.go | 196 +++++++++++++++++++++++++++++++-------- core/trie2/trie_test.go | 113 +++++++++++----------- core/trie2/types.go | 8 -- 10 files changed, 456 insertions(+), 145 deletions(-) create mode 100644 core/trie2/proof.go create mode 100644 core/trie2/proof_test.go delete mode 100644 core/trie2/reader.go delete mode 100644 core/trie2/types.go diff --git a/core/trie2/collector.go b/core/trie2/collector.go index aed7019595..8b47441b98 100644 --- a/core/trie2/collector.go +++ b/core/trie2/collector.go @@ -6,6 +6,7 @@ import ( ) import ( + "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie2/trienode" ) @@ -19,8 +20,8 @@ func newCollector(nodes *trienode.NodeSet) *collector { } // Collects the nodes in the node set and collapses a node into a hash node -func (c *collector) Collect(n node, parallel bool) hashNode { - return c.collect(new(Path), n, parallel).(hashNode) +func (c *collector) Collect(n node, parallel bool) *hashNode { + return c.collect(new(Path), n, parallel).(*hashNode) } func (c *collector) collect(path *Path, n node, parallel bool) node { @@ -46,9 +47,9 @@ func (c *collector) collect(path *Path, n node, parallel bool) node { collapsed := cn.copy() collapsed.children = c.collectChildren(path, cn, parallel) return c.store(path, collapsed) - case hashNode: + case *hashNode: return cn - case valueNode: // each leaf node is stored as a single entry in the node set + case *valueNode: // each leaf node is stored as a single entry in the node set return c.store(path, cn) default: panic(fmt.Sprintf("unknown node type: %T", cn)) @@ -111,6 +112,12 @@ func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2 // Stores the node in the node set and returns the hash node func (c *collector) store(path *Path, n node) node { hash, _ := n.cache() + + if hash == nil { // this is a value node + c.nodes.Add(*path, trienode.NewNode(felt.Felt{}, nodeToBytes(n))) + return n + } + c.nodes.Add(*path, trienode.NewNode(hash.Felt, nodeToBytes(n))) return hash } diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go index 
c3f255dbc8..12c9cef284 100644 --- a/core/trie2/hasher.go +++ b/core/trie2/hasher.go @@ -40,7 +40,7 @@ func (h *hasher) hash(n node) (node, node) { hn := &hashNode{Felt: *collapsed.hash(h.hashFn)} cached.flags.hash = hn return hn, cached - case valueNode, hashNode: + case *valueNode, *hashNode: return n, n default: panic(fmt.Sprintf("unknown node type: %T", n)) diff --git a/core/trie2/node.go b/core/trie2/node.go index d06b657b10..1333a0ba9b 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -37,19 +37,12 @@ type ( valueNode struct{ felt.Felt } ) -const ( - binaryNodeType byte = iota - edgeNodeType - hashNodeType - valueNodeType -) - type nodeFlag struct { hash *hashNode dirty bool } -func newFlag() nodeFlag { return nodeFlag{dirty: false} } +func newFlag() nodeFlag { return nodeFlag{dirty: true} } func (n *binaryNode) hash(hf crypto.HashFn) *felt.Felt { return hf(n.children[0].hash(hf), n.children[1].hash(hf)) @@ -63,13 +56,13 @@ func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { return new(felt.Felt).Add(hf(n.child.hash(hf), &pathFelt), lengthFelt) } -func (n hashNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } -func (n valueNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } +func (n *hashNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } +func (n *valueNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } func (n *binaryNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } -func (n hashNode) cache() (*hashNode, bool) { return nil, true } -func (n valueNode) cache() (*hashNode, bool) { return nil, true } +func (n *hashNode) cache() (*hashNode, bool) { return nil, true } +func (n *valueNode) cache() (*hashNode, bool) { return nil, true } func (n *binaryNode) String() string { return fmt.Sprintf("Binary[\n left: %s\n right: %s\n]", diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go index f3de40bce7..8b0074c96f 100644 --- a/core/trie2/node_enc.go +++ b/core/trie2/node_enc.go @@ -2,7 +2,17 @@ package trie2 import ( "bytes" + "fmt" "sync" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" +) + +const ( + binaryNodeSize = 2 * hashOrValueNodeSize // LeftHash + RightHash + edgeNodeSize = trieutils.MaxBitArraySize + hashOrValueNodeSize // Path + Child Hash (max size, could be less) + hashOrValueNodeSize = felt.Bytes ) var bufferPool = sync.Pool{ @@ -24,18 +34,18 @@ func (n *binaryNode) write(buf *bytes.Buffer) error { } func (n *edgeNode) write(buf *bytes.Buffer) error { - if _, err := n.path.Write(buf); err != nil { + if err := n.child.write(buf); err != nil { return err } - if err := n.child.write(buf); err != nil { + if _, err := n.path.Write(buf); err != nil { return err } return nil } -func (n hashNode) write(buf *bytes.Buffer) error { +func (n *hashNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err } @@ -43,7 +53,7 @@ func (n hashNode) write(buf *bytes.Buffer) error { return nil } -func (n valueNode) write(buf *bytes.Buffer) error { +func (n *valueNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err } @@ -64,3 +74,54 @@ func nodeToBytes(n node) []byte { } return buf.Bytes() } + +func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) { + var ( + n node + err error + isLeaf bool + ) + + isLeaf = pathLen == maxPathLen + + switch len(blob) { + case hashOrValueNodeSize: + 
if isLeaf { + n = &valueNode{Felt: hash} + } else { + n = &hashNode{Felt: hash} + } + case binaryNodeSize: + binary := &binaryNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash + binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen) + if err != nil { + return nil, err + } + binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen) + if err != nil { + return nil, err + } + + n = binary + default: + // Edge node size is capped, if the blob is larger than the max size, it's invalid + if len(blob) > edgeNodeSize { + return nil, fmt.Errorf("invalid node size: %d", len(blob)) + } + edge := &edgeNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash + edge.child, err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen, maxPathLen) + if err != nil { + return nil, err + } + edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]) + + // We do another path length check to see if the node is a leaf + if pathLen+edge.path.Len() == maxPathLen { + edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt} + } + + n = edge + } + + return n, nil +} diff --git a/core/trie2/proof.go b/core/trie2/proof.go new file mode 100644 index 0000000000..79be95f4e3 --- /dev/null +++ b/core/trie2/proof.go @@ -0,0 +1,156 @@ +package trie2 + +import ( + "fmt" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" +) + +type ProofNodeSet = utils.OrderedSet[felt.Felt, node] + +func NewProofNodeSet() *ProofNodeSet { + return utils.NewOrderedSet[felt.Felt, node]() +} + +// Prove generates a Merkle proof for a given key in the trie. +// The result contains the proof nodes on the path from the root to the leaf. +// The value is included in the proof if the key is present in the trie. +// If the key is not present, the proof will contain the nodes on the path to the closest ancestor. +func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { + if t.committed { + return ErrCommitted + } + + path := t.FeltToPath(key) + k := &path + + var ( + prefix *Path + nodes []node + rn = t.root + ) + + for k.Len() > 0 && rn != nil { + switch n := rn.(type) { + case *edgeNode: + if !n.pathMatches(k) { + rn = nil // Trie doesn't contain the key + } else { + rn = n.child + prefix.Append(prefix, n.path) + k.LSBs(k, n.path.Len()) + } + nodes = append(nodes, n) + case *binaryNode: + bit := k.MSB() + prefix.AppendBit(prefix, bit) + k.LSBs(k, 1) + nodes = append(nodes, n) + case *hashNode: + resolved, err := t.resolveNode(n, *prefix) + if err != nil { + return err + } + rn = resolved + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } + } + + h := newHasher(t.hashFn, false) + + for _, n := range nodes { + hashed, cached := h.hash(n) // subsequent nodes are cached + proof.Put(hashed.(*hashNode).Felt, cached) + } + + return nil +} + +// GetRangeProof generates a range proof for the given range of keys. +// The proof contains the proof nodes on the path from the root to the closest ancestor of the left and right keys. 
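// A minimal usage sketch (tr, firstKey, lastKey, keys and values are placeholders for
// an existing trie and the requested range; root is tr.Hash()):
//
//	proof := NewProofNodeSet()
//	if err := tr.GetRangeProof(firstKey, lastKey, proof); err != nil {
//		return err
//	}
//	// The proof is later consumed together with the in-range keys and values:
//	more, err := VerifyRangeProof(&root, firstKey, keys, values, proof)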
+func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSet) error { + err := t.Prove(leftKey, proofSet) + if err != nil { + return err + } + + // If they are the same key, don't need to generate the proof again + if leftKey.Equal(rightKey) { + return nil + } + + err = t.Prove(rightKey, proofSet) + if err != nil { + return err + } + + return nil +} + +// VerifyProof verifies that a proof path is valid for a given key in a binary trie. +// It walks through the proof nodes, verifying each step matches the expected path to reach the key. +// +// The proof is considered invalid if: +// - Any proof node is missing from the node set +// - Any node's computed hash doesn't match its expected hash +// - The path bits don't match the key bits +// - The proof ends before processing all key bits +func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) (felt.Felt, error) { + keyBits := new(Path).SetFelt(contractClassTrieHeight, key) + expected := *root + h := newHasher(hash, false) + + for { + node, ok := proof.Get(expected) + if !ok { + return felt.Zero, fmt.Errorf("proof node not found, expected hash: %s", expected.String()) + } + + nHash, _ := h.hash(node) + + // Verify the hash matches + if !nHash.(*hashNode).Felt.Equal(&expected) { + return felt.Zero, fmt.Errorf("proof node hash mismatch, expected hash: %s, got hash: %s", expected.String(), nHash.String()) + } + + child := get(node, keyBits, false) + switch cld := child.(type) { + case nil: + return felt.Zero, nil + case *hashNode: + expected = cld.Felt + case *valueNode: + return cld.Felt, nil + } + } +} + +func get(rn node, key *Path, skipResolved bool) node { + for { + switch n := rn.(type) { + case *edgeNode: + if !n.pathMatches(key) { + return nil + } + rn = n.child + key.LSBs(key, n.path.Len()) + case *binaryNode: + bit := key.MSB() + rn = n.children[bit] + key.LSBs(key, 1) + case *hashNode: + return n + case *valueNode: + return n + case nil: + return nil + } + + if !skipResolved { + return rn + } + } +} diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go new file mode 100644 index 0000000000..929a75dad6 --- /dev/null +++ b/core/trie2/proof_test.go @@ -0,0 +1 @@ +package trie2 diff --git a/core/trie2/reader.go b/core/trie2/reader.go deleted file mode 100644 index a804daa5be..0000000000 --- a/core/trie2/reader.go +++ /dev/null @@ -1,24 +0,0 @@ -package trie2 - -import ( - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" -) - -type NodeReader interface { - Node(owner *felt.Felt, path *Path, hash *felt.Felt) ([]byte, error) -} - -type nodeReader struct { - txn db.Transaction - trieType TrieType -} - -func (n *nodeReader) Node(owner *felt.Felt, path *Path, hash *felt.Felt) ([]byte, error) { - panic("implement me") -} - -type trieReader struct { - owner *felt.Felt - reader NodeReader -} diff --git a/core/trie2/trie.go b/core/trie2/trie.go index b51384714e..2085e0774a 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -1,25 +1,43 @@ package trie2 import ( + "bytes" "fmt" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/triedb" "github.com/NethermindEth/juno/core/trie2/trienode" - "github.com/NethermindEth/juno/core/trie2/utils" + "github.com/NethermindEth/juno/core/trie2/trieutils" "github.com/NethermindEth/juno/db" ) -type Path = utils.BitArray +const contractClassTrieHeight = 251 + +var emptyRoot = felt.Felt{} + +type Path = trieutils.BitArray type Trie struct { - txn 
db.Transaction - owner felt.Felt + // Height of the trie height uint8 - root node - reader interface{} // TODO(weiihann): implement reader - // committed bool + + // The owner of the trie, only used for contract trie. If not empty, this is a storage trie. + owner felt.Felt + + // The root node of the trie + root node + + // Hash function used to hash the trie nodes hashFn crypto.HashFn + + // The underlying database to store and retrieve trie nodes + db *triedb.Database + + // Check if the trie has been committed. Trie is unusable once committed. + committed bool + + // Maintains the records of trie changes, ensuring all nodes are modified or garbage collected properly tracer *tracer // Tracks the number of leaves inserted since the last hashing operation @@ -29,48 +47,80 @@ type Trie struct { pendingUpdates int } -// TODO(weiihann): implement this -func NewTrie(height uint8, hashFn crypto.HashFn) *Trie { - return &Trie{ +func New(id *ID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, error) { + database := triedb.New(txn, id.Bucket()) + tr := &Trie{ + owner: id.Owner, height: height, hashFn: hashFn, + db: database, tracer: newTracer(), } + + if id.Root != emptyRoot { + root, err := tr.resolveNode(&hashNode{Felt: id.Root}, Path{}) + if err != nil { + return nil, err + } + tr.root = root + } + return tr, nil } // Modifies or inserts a key-value pair in the trie. // If value is zero, the key is deleted from the trie. func (t *Trie) Update(key, value *felt.Felt) error { - // if t.commited { - // return ErrCommitted - // } + if t.committed { + return ErrCommitted + } + if err := t.update(key, value); err != nil { return err } t.pendingUpdates++ t.pendingHashes++ + return nil } // Retrieves the value associated with the given key. // Returns felt.Zero if the key doesn't exist. // May update the trie's internal structure if nodes need to be resolved. -func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { +func (t *Trie) Get(key *felt.Felt) (felt.Felt, error) { + if t.committed { + return felt.Zero, ErrCommitted + } + k := t.FeltToPath(key) - // TODO(weiihann): get the value directly from the reader + // We first check if the value node exists in the trie database directly + v, err := t.resolveNode(&hashNode{}, k) + if _, ok := v.(*valueNode); ok && err == nil { + return v.(*valueNode).Felt, nil + } + + // Otherwise, we need to traverse the trie to find the value node + var ret felt.Felt val, root, didResolve, err := t.get(t.root, new(Path), &k) // In Starknet, a non-existent key is mapped to felt.Zero if val == nil { - val = &felt.Zero + ret = felt.Zero + } else { + ret = *val } + if err == nil && didResolve { t.root = root } - return val, err + + return ret, err } // Removes the given key from the trie. func (t *Trie) Delete(key *felt.Felt) error { + if t.committed { + return ErrCommitted + } + k := t.FeltToPath(key) _, n, err := t.delete(t.root, new(Path), &k) if err != nil { @@ -89,11 +139,16 @@ func (t *Trie) Hash() felt.Felt { return hash.(*hashNode).Felt } -func (t *Trie) Commit() felt.Felt { +// Collapses the trie into a single hash node and flush the node changes to the database. 
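// Expected call pattern (a sketch, assuming the constructor helpers at the bottom of this file;
// key and value are placeholder felts):
//
//	tr, _ := NewEmptyPedersen()
//	_ = tr.Update(key, value)
//	root, err := tr.Commit() // flushes the collected nodes to the database
//	// After Commit the trie is marked committed; further Update/Get/Delete calls return ErrCommitted.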
+func (t *Trie) Commit() (felt.Felt, error) { + defer func() { + t.committed = true + }() + rootHash := t.Hash() if hashedNode, dirty := t.root.cache(); !dirty { t.root = hashedNode - return rootHash + return rootHash, nil } nodes := trienode.NewNodeSet(t.owner) @@ -103,17 +158,31 @@ func (t *Trie) Commit() felt.Felt { t.root = newCollector(nodes).Collect(t.root, t.pendingUpdates > 100) // TODO(weiihann): 100 is arbitrary t.pendingUpdates = 0 - return rootHash + + err := nodes.ForEach(true, func(key trieutils.BitArray, node *trienode.Node) error { + if node.IsDeleted() { + return t.db.Delete(t.owner, key) + } + return t.db.Put(t.owner, key, node.Blob()) + }) + if err != nil { + return felt.Felt{}, err + } + + return rootHash, nil } func (t *Trie) Copy() *Trie { return &Trie{ - txn: t.txn, - owner: t.owner, - height: t.height, - root: t.root, - hashFn: t.hashFn, - tracer: t.tracer.copy(), + height: t.height, + owner: t.owner, + root: t.root, + hashFn: t.hashFn, + committed: t.committed, + db: t.db, + tracer: t.tracer.copy(), + pendingHashes: t.pendingHashes, + pendingUpdates: t.pendingUpdates, } } @@ -131,15 +200,20 @@ func (t *Trie) get(n node, prefix, key *Path) (*felt.Felt, node, bool, error) { return val, n, didResolve, err case *binaryNode: bit := key.MSB() - val, child, didResolve, err := t.get(n.children[bit], new(Path).SetBit(bit), key.LSBs(key, 1)) + val, child, didResolve, err := t.get(n.children[bit], new(Path).AppendBit(prefix, bit), key.LSBs(key, 1)) if err == nil && didResolve { n = n.copy() n.children[bit] = child } return val, n, didResolve, err - case hashNode: - panic("TODO(weiihann): implement me") - case valueNode: + case *hashNode: + child, err := t.resolveNode(n, *key) + if err != nil { + return nil, nil, false, err + } + value, newNode, _, err := t.get(child, prefix, key) + return value, newNode, true, err + case *valueNode: return &n.Felt, n, false, nil case nil: return nil, nil, false, nil @@ -159,7 +233,7 @@ func (t *Trie) update(key, value *felt.Felt) error { } t.root = n } else { - _, n, err := t.insert(t.root, new(Path), &k, valueNode{Felt: *value}) + _, n, err := t.insert(t.root, new(Path), &k, &valueNode{Felt: *value}) if err != nil { return err } @@ -171,8 +245,8 @@ func (t *Trie) update(key, value *felt.Felt) error { func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) { // We reach the end of the key if key.Len() == 0 { - if v, ok := n.(valueNode); ok { - vFelt := value.(valueNode).Felt + if v, ok := n.(*valueNode); ok { + vFelt := value.(*valueNode).Felt return v.Equal(&vFelt), value, nil } return true, value, nil @@ -226,7 +300,7 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) // Go to the child node based on the MSB of the key bit := key.MSB() dirty, newNode, err := t.insert( - n.children[bit], new(Path).Append(prefix, new(Path).SetBit(bit)), new(Path).LSBs(key, 1), value, + n.children[bit], new(Path).AppendBit(prefix, bit), new(Path).LSBs(key, 1), value, ) if !dirty || err != nil { return false, n, err @@ -244,8 +318,16 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) } // Otherwise, return a new edge node with the Path being the key and the value as the child return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil - case hashNode: - panic("TODO(weiihann): implement me") + case *hashNode: + child, err := t.resolveNode(n, *key) + if err != nil { + return false, n, err + } + dirty, newNode, err := t.insert(child, prefix, key, value) + if 
!dirty || err != nil { + return false, child, err + } + return true, newNode, nil default: panic(fmt.Sprintf("unknown node type: %T", n)) } @@ -265,7 +347,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { return true, nil, nil } - // Otherwise, key is longer than current node Path, so we need to delete the child. + // Otherwise, key is longer than current node path, so we need to delete the child. // Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie. keyPrefix := new(Path).MSBs(key, n.path.Len()) dirty, child, err := t.delete(n.child, new(Path).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len())) @@ -310,17 +392,43 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { // other child is not an edge node, create a new edge node with the bit prefix as the Path // containing the other child as the child return true, &edgeNode{path: bitPrefix, child: n.children[other], flags: newFlag()}, nil - case valueNode: + case *valueNode: return true, nil, nil case nil: return false, nil, nil - case hashNode: - panic("TODO(weiihann): implement me") + case *hashNode: + child, err := t.resolveNode(n, *key) + if err != nil { + return false, nil, err + } + + dirty, newNode, err := t.delete(child, prefix, key) + if !dirty || err != nil { + return false, child, err + } + return true, newNode, nil default: panic(fmt.Sprintf("unknown node type: %T", n)) } } +func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + _, err := t.db.Get(buf, t.owner, path) + if err != nil { + return nil, err + } + + blob := buf.Bytes() + return decodeNode(blob, hash.Felt, path.Len(), t.height) +} + func (t *Trie) hashRoot() (node, node) { h := newHasher(t.hashFn, t.pendingHashes > 100) // TODO(weiihann): 100 is arbitrary hashed, cached := h.hash(t.root) @@ -342,3 +450,11 @@ func (t *Trie) String() string { } return t.root.String() } + +func NewEmptyPedersen() (*Trie, error) { + return New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) +} + +func NewEmptyPoseidon() (*Trie, error) { + return New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Poseidon, db.NewMemTransaction()) +} diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index edfbbe0a8e..360ece6c00 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -4,68 +4,64 @@ import ( "math/rand" "testing" - "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" ) func TestUpdate(t *testing.T) { - tr, records := nonRandomTrie(t, 1000) - - for _, record := range records { - err := tr.Update(record.key, record.value) - require.NoError(t, err) - - got, err := tr.Get(record.key) - require.NoError(t, err) - require.Equal(t, record.value, got) + verifyRecords := func(t *testing.T, tr *Trie, records []*keyValue) { + t.Helper() + for _, record := range records { + got, err := tr.Get(record.key) + require.NoError(t, err) + require.True(t, got.Equal(record.value), "expected %v, got %v", record.value, got) + } } -} -func TestUpdateRandom(t *testing.T) { - tr, records := randomTrie(t, 1000) - - for _, record := range records { - got, err := tr.Get(record.key) - require.NoError(t, err) + t.Run("sequential", func(t *testing.T) { + tr, records := nonRandomTrie(t, 10000) + verifyRecords(t, tr, records) + 
}) - if !got.Equal(record.value) { - t.Fatalf("expected %s, got %s", record.value, got) - } - } + t.Run("random", func(t *testing.T) { + tr, records := randomTrie(t, 10000) + verifyRecords(t, tr, records) + }) } func TestDelete(t *testing.T) { - tr, records := nonRandomTrie(t, 10000) - - for _, record := range records { - err := tr.Delete(record.key) - require.NoError(t, err) + verifyDelete := func(t *testing.T, tr *Trie, records []*keyValue) { + t.Helper() + for _, record := range records { + err := tr.Delete(record.key) + require.NoError(t, err) - got, err := tr.Get(record.key) - require.NoError(t, err) - require.Equal(t, got, &felt.Zero) + got, err := tr.Get(record.key) + require.NoError(t, err) + require.True(t, got.Equal(&felt.Zero), "expected %v, got %v", &felt.Zero, got) + } } -} - -func TestDeleteRandom(t *testing.T) { - tr, records := randomTrie(t, 10000) - for i := len(records) - 1; i >= 0; i-- { - err := tr.Delete(records[i].key) - require.NoError(t, err) + t.Run("sequential", func(t *testing.T) { + tr, records := nonRandomTrie(t, 10000) + verifyDelete(t, tr, records) + }) - got, err := tr.Get(records[i].key) - require.NoError(t, err) - require.Equal(t, got, &felt.Zero) - } + t.Run("random", func(t *testing.T) { + tr, records := randomTrie(t, 10000) + // Delete in reverse order for random case + for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { + records[i], records[j] = records[j], records[i] + } + verifyDelete(t, tr, records) + }) } // The expected hashes are taken from Pathfinder's tests func TestHash(t *testing.T) { t.Run("one leaf", func(t *testing.T) { - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)) require.NoError(t, err) hash := tr.Hash() @@ -75,7 +71,7 @@ func TestHash(t *testing.T) { }) t.Run("two leaves", func(t *testing.T) { - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() err := tr.Update(new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(2)) require.NoError(t, err) err = tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(3)) @@ -87,7 +83,7 @@ func TestHash(t *testing.T) { }) t.Run("three leaves", func(t *testing.T) { - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() keys := []*felt.Felt{ new(felt.Felt).SetUint64(16), @@ -132,7 +128,7 @@ func TestHash(t *testing.T) { new(felt.Felt).SetUint64(5), } - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() for i := range keys { err := tr.Update(keys[i], vals[i]) require.NoError(t, err) @@ -160,7 +156,7 @@ func TestHash(t *testing.T) { utils.HexToFelt(t, "0xdd"), } - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() for i := range keys { err := tr.Update(keys[i], vals[i]) require.NoError(t, err) @@ -173,11 +169,24 @@ func TestHash(t *testing.T) { } func TestCommit(t *testing.T) { - tr, _ := nonRandomTrie(t, 1000) - tr2 := tr.Copy() + t.Run("sequential", func(t *testing.T) { + tr, _ := nonRandomTrie(t, 10000) + + _, err := tr.Commit() + require.NoError(t, err) + }) + + t.Run("random", func(t *testing.T) { + tr, _ := randomTrie(t, 10000) + + _, err := tr.Commit() + require.NoError(t, err) + }) +} - root := tr.Commit() - require.Equal(t, root, tr2.Hash()) +func TestTrieOpsRandom(t *testing.T) { + t.Skip() + panic("implement me") } type keyValue struct { @@ -186,7 +195,7 @@ type keyValue struct { } func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() records := 
make([]*keyValue, numKeys) for i := 1; i < numKeys+1; i++ { @@ -202,7 +211,7 @@ func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) - tr := NewTrie(251, crypto.Pedersen) + tr, _ := NewEmptyPedersen() records := make([]*keyValue, n) for i := 0; i < n; i++ { @@ -231,7 +240,7 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie { t.Fatal("records must have at least one element") } - tempTrie := NewTrie(251, crypto.Pedersen) + tempTrie, _ := NewEmptyPedersen() for _, record := range records { err := tempTrie.Update(record.key, record.value) diff --git a/core/trie2/types.go b/core/trie2/types.go deleted file mode 100644 index 23ce5acc3f..0000000000 --- a/core/trie2/types.go +++ /dev/null @@ -1,8 +0,0 @@ -package trie2 - -type TrieType uint8 - -const ( - ClassTrie TrieType = iota + 1 - StateTrie -) From d48eac681e56a72469eb077e7bc8c5f928a8747e Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 17 Jan 2025 15:37:54 +0800 Subject: [PATCH 22/44] bitarray changes --- core/trie2/trieutils/bitarray.go | 64 ++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index 63fc134593..a1cfc60d80 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "fmt" "math" "math/bits" @@ -52,7 +53,6 @@ func (b *BitArray) Len() uint8 { func (b *BitArray) Bytes() [32]byte { var res [32]byte - b.truncateToLength() binary.BigEndian.PutUint64(res[0:8], b.words[3]) binary.BigEndian.PutUint64(res[8:16], b.words[2]) binary.BigEndian.PutUint64(res[16:24], b.words[1]) @@ -77,31 +77,28 @@ func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { return b.Set(x) } - b.Set(x) b.len = n - // Clear all words beyond what's needed switch { case n == 0: b.words = [4]uint64{0, 0, 0, 0} case n <= 64: - mask := maxUint64 >> (64 - n) - b.words[0] &= mask - b.words[1] = 0 - b.words[2] = 0 - b.words[3] = 0 + b.words[0] = x.words[0] & (maxUint64 >> (64 - n)) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 case n <= 128: - mask := maxUint64 >> (128 - n) - b.words[1] &= mask - b.words[2] = 0 - b.words[3] = 0 + b.words[0] = x.words[0] + b.words[1] = x.words[1] & (maxUint64 >> (128 - n)) + b.words[2], b.words[3] = 0, 0 case n <= 192: - mask := maxUint64 >> (192 - n) - b.words[2] &= mask + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] & (maxUint64 >> (192 - n)) b.words[3] = 0 default: - mask := maxUint64 >> (256 - uint16(n)) - b.words[3] &= mask + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] & (maxUint64 >> (256 - uint16(n))) } return b @@ -435,12 +432,25 @@ func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { // Example: // // [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} -func (b *BitArray) UnmarshalBinary(data []byte) { - b.len = data[0] +func (b *BitArray) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return errors.New("empty data") + } + + length := data[0] + byteCount := (uint(length) + 7) / 8 // Round up to nearest byte + + if len(data) < int(byteCount)+1 { + return fmt.Errorf("invalid data length: got %d bytes, expected %d", len(data), byteCount+1) + } + + b.len = length var bs [32]byte - copy(bs[32-b.byteCount():], data[1:]) + copy(bs[32-byteCount:], data[1:]) 
b.setBytes32(bs[:]) + + return nil } // Sets the bit array to the same value as x. @@ -531,7 +541,7 @@ func (b *BitArray) setFelt(f *felt.Felt) { } func (b *BitArray) setBytes32(data []byte) { - _ = data[31] + _ = data[31] // bound check hint, see https://golang.org/issue/14808 b.words[3] = binary.BigEndian.Uint64(data[0:8]) b.words[2] = binary.BigEndian.Uint64(data[8:16]) b.words[1] = binary.BigEndian.Uint64(data[16:24]) @@ -589,6 +599,20 @@ func (b *BitArray) clear() *BitArray { // Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. // +// Example: +// +// b := &BitArray{ +// len: 5, +// words: [4]uint64{ +// 0xFFFFFFFFFFFFFFFF, // Before: all bits are 1 +// 0x0, 0x0, 0x0, +// }, +// } +// b.truncateToLength() +// // After: only first 5 bits remain +// // words[0] = 0x000000000000001F +// // words[1..3] = 0x0 +// //nolint:mnd func (b *BitArray) truncateToLength() { switch { From e2c367274f3d2394cd5258bdab76cbef994ce9c6 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 20 Jan 2025 11:07:29 +0800 Subject: [PATCH 23/44] bitarray changes --- core/trie2/trieutils/bitarray.go | 171 +++++++++++++++++++++++++++---- 1 file changed, 149 insertions(+), 22 deletions(-) diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index a1cfc60d80..3b404cb37c 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -13,10 +13,9 @@ import ( ) const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) MaxBitArraySize = 33 // (1 + 4 * 8) bytes - - maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF - maxUint8 = uint8(math.MaxUint8) ) var emptyBitArray = new(BitArray) @@ -40,8 +39,8 @@ func NewBitArray(length uint8, val uint64) BitArray { // Returns the felt representation of the bit array. func (b *BitArray) Felt() felt.Felt { var f felt.Felt - bs := b.Bytes() - f.SetBytes(bs[:]) + bt := b.Bytes() + f.SetBytes(bt[:]) return f } @@ -262,10 +261,8 @@ func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { // //nolint:mnd func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { - b.Set(x) - if x.len == 0 || n == 0 { - return b + return b.Set(x) } // If the result will overflow, we set the length to the max length @@ -277,8 +274,6 @@ func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { } switch { - case n == 0: - return b case n >= 192: b.lsh192(x) n -= 192 @@ -295,6 +290,7 @@ func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) b.words[1] <<= n default: + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[3], x.words[2], x.words[1], x.words[0] b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) @@ -312,22 +308,20 @@ func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { // y = 111 (len=3) // Append(x,y) = 000111 (len=6) func (b *BitArray) Append(x, y *BitArray) *BitArray { - if x.len == 0 { + if x.len == 0 || y.len == maxUint8 { return b.Set(y) } if y.len == 0 { return b.Set(x) } - // First copy x - b.Set(x) - // Then shift left by y's length and OR with y - return b.Lsh(b, y.len).Or(b, y) + return b.Lsh(x, y.len).Or(b, y) } +// Sets the bit array to the concatenation of x and a single bit. 
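// Example of the intended semantics:
//
//	x = 000 (len=3)
//	AppendBit(x, 1) = 0001 (len=4)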
func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { - return b.Append(x, new(BitArray).SetBit(bit)) + return b.Append(b, new(BitArray).SetBit(bit)) } // Sets the bit array to x | y and returns the bit array. @@ -481,8 +475,124 @@ func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { // Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. // If the data is larger than 32 bytes, only the first 32 bytes are used. +// +//nolint:mnd,funlen,gocyclo func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { - b.setBytes32(data) + switch l := len(data); l { + case 0: + b.clear() + case 1: + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(data[0]) + case 2: + _ = data[1] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])) + case 3: + _ = data[2] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + case 4: + _ = data[3] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])) + case 5: + _ = data[4] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint40(data[0:5]) + case 6: + _ = data[5] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint48(data[0:6]) + case 7: + _ = data[6] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint56(data[0:7]) + case 8: + _ = data[7] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, binary.BigEndian.Uint64(data[0:8]) + case 9: + _ = data[8] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + case 10: + _ = data[9] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + case 11: + _ = data[10] + b.words[3], b.words[2] = 0, 0 + b.words[1], b.words[0] = uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16, binary.BigEndian.Uint64(data[3:11]) + case 12: + _ = data[11] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + case 13: + _ = data[12] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + case 14: + _ = data[13] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + case 15: + _ = data[14] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + case 16: + _ = data[15] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) + case 17: + _ = data[16] + b.words[3], b.words[2] = 0, uint64(data[0]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[1:9]), binary.BigEndian.Uint64(data[9:17]) + case 18: + _ = data[17] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[0:2])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[2:10]), binary.BigEndian.Uint64(data[10:18]) + case 19: + _ = data[18] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[3:11]), binary.BigEndian.Uint64(data[11:19]) + case 20: + _ = data[19] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint32(data[0:4])) + b.words[1], b.words[0] = 
binary.BigEndian.Uint64(data[4:12]), binary.BigEndian.Uint64(data[12:20]) + case 21: + _ = data[20] + b.words[3], b.words[2] = 0, bigEndianUint40(data[0:5]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[5:13]), binary.BigEndian.Uint64(data[13:21]) + case 22: + _ = data[21] + b.words[3], b.words[2] = 0, bigEndianUint48(data[0:6]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[6:14]), binary.BigEndian.Uint64(data[14:22]) + case 23: + _ = data[22] + b.words[3], b.words[2] = 0, bigEndianUint56(data[0:7]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[7:15]), binary.BigEndian.Uint64(data[15:23]) + case 24: + _ = data[23] + b.words[3], b.words[2] = 0, binary.BigEndian.Uint64(data[0:8]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[8:16]), binary.BigEndian.Uint64(data[16:24]) + case 25: + _ = data[24] + b.words[3], b.words[2] = uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[9:17]), binary.BigEndian.Uint64(data[17:25]) + case 26: + _ = data[25] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[10:18]), binary.BigEndian.Uint64(data[18:26]) + case 27: + _ = data[26] + b.words[3] = uint64(binary.BigEndian.Uint16(data[1:3])) | uint64(data[0])<<16 + b.words[2] = binary.BigEndian.Uint64(data[3:11]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[11:19]), binary.BigEndian.Uint64(data[19:27]) + case 28: + _ = data[27] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[12:20]), binary.BigEndian.Uint64(data[20:28]) + case 29: + _ = data[28] + b.words[3], b.words[2] = bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[13:21]), binary.BigEndian.Uint64(data[21:29]) + case 30: + _ = data[29] + b.words[3], b.words[2] = bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[14:22]), binary.BigEndian.Uint64(data[22:30]) + case 31: + _ = data[30] + b.words[3], b.words[2] = bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[15:23]), binary.BigEndian.Uint64(data[23:31]) + default: + b.setBytes32(data) + } b.len = length b.truncateToLength() return b @@ -519,17 +629,17 @@ func (b *BitArray) Copy() BitArray { // Returns the encoded string representation of the bit array. func (b *BitArray) EncodedString() string { var res []byte - bs := b.Bytes() + bt := b.Bytes() res = append(res, b.len) - res = append(res, bs[:]...) + res = append(res, bt[:]...) return string(res) } // Returns a string representation of the bit array. // This is typically used for logging or debugging. func (b *BitArray) String() string { - bs := b.Bytes() - return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bs[:])) + bt := b.Bytes() + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bt[:])) } func (b *BitArray) setFelt(f *felt.Felt) { @@ -552,7 +662,6 @@ func (b *BitArray) setBytes32(data []byte) { // It rounds up to the nearest byte. 
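// For example: len=0 -> 0 bytes, len=1..8 -> 1 byte, len=9..16 -> 2 bytes, len=251 -> 32 bytes.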
func (b *BitArray) byteCount() uint { const bits8 = 8 - // Cast to uint16 to avoid overflow return (uint(b.len) + (bits8 - 1)) / uint(bits8) } @@ -686,3 +795,21 @@ func (b *BitArray) Cmp(x *BitArray) int { return 1 } + +func bigEndianUint40(b []byte) uint64 { + _ = b[4] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[4]) | uint64(b[3])<<8 | uint64(b[2])<<16 | uint64(b[1])<<24 | + uint64(b[0])<<32 +} + +func bigEndianUint48(b []byte) uint64 { + _ = b[5] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[5]) | uint64(b[4])<<8 | uint64(b[3])<<16 | uint64(b[2])<<24 | + uint64(b[1])<<32 | uint64(b[0])<<40 +} + +func bigEndianUint56(b []byte) uint64 { + _ = b[6] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[6]) | uint64(b[5])<<8 | uint64(b[4])<<16 | uint64(b[3])<<24 | + uint64(b[2])<<32 | uint64(b[1])<<40 | uint64(b[0])<<48 +} From 81ed0866d9997c53d554764c5b477b863adcad49 Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 20 Jan 2025 18:33:19 +0800 Subject: [PATCH 24/44] at this point range proof tests pass 50% --- core/trie/proof_test.go | 1 + core/trie2/errors.go | 5 +- core/trie2/proof.go | 427 +++++++++++++++++- core/trie2/proof_test.go | 744 +++++++++++++++++++++++++++++++ core/trie2/tracer.go | 18 +- core/trie2/trie.go | 36 +- core/trie2/trie_test.go | 10 +- core/trie2/trieutils/bitarray.go | 55 +++ 8 files changed, 1269 insertions(+), 27 deletions(-) diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index 5a43932042..dcbc9dc43e 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -284,6 +284,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { } _, err = trie.VerifyRangeProof(root, first, keys, values, proof) + t.Log(err) require.Error(t, err) } diff --git a/core/trie2/errors.go b/core/trie2/errors.go index 306fe62059..aace1a44d9 100644 --- a/core/trie2/errors.go +++ b/core/trie2/errors.go @@ -2,4 +2,7 @@ package trie2 import "errors" -var ErrCommitted = errors.New("trie is committed") +var ( + ErrCommitted = errors.New("trie is committed") + ErrEmptyRange = errors.New("empty range") +) diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 79be95f4e3..9b5132a5d2 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -1,6 +1,7 @@ package trie2 import ( + "errors" "fmt" "github.com/NethermindEth/juno/core/crypto" @@ -27,8 +28,8 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { k := &path var ( - prefix *Path nodes []node + prefix = new(Path) rn = t.root ) @@ -45,6 +46,7 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { nodes = append(nodes, n) case *binaryNode: bit := k.MSB() + rn = n.children[bit] prefix.AppendBit(prefix, bit) k.LSBs(k, 1) nodes = append(nodes, n) @@ -59,8 +61,9 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { } } + // TODO: ideally Hash() should be called before Prove() so that the hashes are cached + // There should be a better way to do this h := newHasher(t.hashFn, false) - for _, n := range nodes { hashed, cached := h.hash(n) // subsequent nodes are cached proof.Put(hashed.(*hashNode).Felt, cached) @@ -124,8 +127,428 @@ func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) expected = cld.Felt case *valueNode: return cld.Felt, nil + case *edgeNode, *binaryNode: + if hash, _ := cld.cache(); hash != nil { + expected = hash.Felt + } + } + } +} + +// VerifyRangeProof checks the validity of given key-value pairs and range proof against a 
provided root hash. +// The key-value pairs should be consecutive (no gaps) and monotonically increasing. +// The range proof contains two edge proofs: one for the first key and another for the last key. +// Both edge proofs can be for existent or non-existent keys. +// This function handles the following special cases: +// +// - All elements proof: The proof can be nil if the range includes all leaves in the trie. +// - Single element proof: Both left and right edge proofs are identical, and the range contains only one element. +// - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements. +// +// The function returns a boolean indicating if there are more elements and an error if the range proof is invalid. +// +// TODO(weiihann): Given a binary leaf and a left-sibling first key, if the right sibling is removed, the proof would still be valid. +// Conversely, given a binary leaf and a right-sibling last key, if the left sibling is removed, the proof would still be valid. +// Range proof should not be valid for both of these cases, but currently is, which is an attack vector. +// The problem probably lies in how we do root hash calculation. +func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo + // Ensure the number of keys and values are the same + if len(keys) != len(values) { + return false, fmt.Errorf("inconsistent length of proof data, keys: %d, values: %d", len(keys), len(values)) + } + + // Ensure all keys are monotonically increasing and values contain no deletions + for i := 0; i < len(keys); i++ { + if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 { + return false, errors.New("keys are not monotonic increasing") + } + + if values[i] == nil || values[i].Equal(&felt.Zero) { + return false, errors.New("range contains empty leaf") + } + } + + // Special case: no edge proof provided; the given range contains all leaves in the trie + if proof == nil { + tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen) + for i, key := range keys { + if err := tr.Update(key, values[i]); err != nil { + return false, err + } + } + + recomputedRoot := tr.Hash() + if !recomputedRoot.Equal(rootHash) { + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", rootHash.String(), recomputedRoot.String()) + } + + return false, nil // no more elements available + } + + var firstKey Path + firstKey.SetFelt(contractClassTrieHeight, first) + + // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values + // Empty range proof with more elements on the right is not accepted in this function. + // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. 
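	// For illustration (assumed caller-side usage), an empty-range verification is expected
	// to look roughly like:
	//
	//	proof := NewProofNodeSet()
	//	_ = tr.Prove(first, proof)                                   // single edge proof from the prover
	//	more, err := VerifyRangeProof(&root, first, nil, nil, proof) // no keys/values in range
	//
	// where any leaf still reachable to the right of first causes verification to fail.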
+ if len(keys) == 0 { + rootKey, val, err := proofToPath(rootHash, nil, firstKey, proof, true) + if err != nil { + return false, err + } + + if val != nil || hasRightElement(rootKey, firstKey) { + return false, errors.New("more entries available") + } + + return false, nil + } + + last := keys[len(keys)-1] + + var lastKey Path + lastKey.SetFelt(contractClassTrieHeight, last) + + // Special case: there is only one element and two edge keys are the same + if len(keys) == 1 && firstKey.Equal(&lastKey) { + root, val, err := proofToPath(rootHash, nil, firstKey, proof, false) + if err != nil { + return false, err + } + + firstItemKey := new(Path).SetFelt(contractClassTrieHeight, keys[0]) + if !firstKey.Equal(firstItemKey) { + return false, errors.New("correct proof but invalid key") + } + + if val == nil || !values[0].Equal(val) { + return false, errors.New("correct proof but invalid value") + } + + return hasRightElement(root, firstKey), nil + } + + // In all other cases, we require two edge paths available. + // First, ensure that the last key is greater than the first key + if last.Cmp(first) <= 0 { + return false, errors.New("last key is less than first key") + } + + root, _, err := proofToPath(rootHash, nil, firstKey, proof, true) + if err != nil { + return false, err + } + + root, _, err = proofToPath(rootHash, root, lastKey, proof, true) + if err != nil { + return false, err + } + + // TODO: unset internal + // empty, err := unsetInternal(root, firstKey, lastKey) + // if err != nil { + // return false, err + // } + + tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen) + // if !empty { + // tr.root = root + // } + for i, key := range keys { + if err := tr.Update(key, values[i]); err != nil { + return false, err + } + } + + newRoot := tr.Hash() + + // Verify that the recomputed root hash matches the provided root hash + if !newRoot.Equal(rootHash) { + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", rootHash.String(), newRoot.String()) + } + + return hasRightElement(root, lastKey), nil +} + +// proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining +// as hashes. The given edge proof can be existent or non-existent. 
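// Sketch of how it is driven from VerifyRangeProof above (rootHash, firstKey, lastKey and
// proof are the values already in scope there):
//
//	root, _, err := proofToPath(rootHash, nil, firstKey, proof, true) // resolve the left edge path
//	root, _, err = proofToPath(rootHash, root, lastKey, proof, true)  // graft the right edge onto the same root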
+func proofToPath(rootHash *felt.Felt, root node, keyBits Path, proof *ProofNodeSet, allowNonExistent bool) (node, *felt.Felt, error) { + // Retrieves the node from the proof node set given the node hash + retrieveNode := func(hash felt.Felt) (node, error) { + n, _ := proof.Get(hash) + if n == nil { + return nil, fmt.Errorf("proof node not found, expected hash: %s", hash.String()) + } + return n, nil + } + + // Must resolve the root node first if it's not provided + if root == nil { + n, err := retrieveNode(*rootHash) + if err != nil { + return nil, nil, err + } + root = n + } + + var ( + err error + child, parent node + val *felt.Felt + ) + + parent = root + for { + msb := keyBits.MSB() // key gets modified in get(), we need the current msb to get the correct child during linking + child = get(parent, &keyBits, false) + switch n := child.(type) { + case nil: + if allowNonExistent { + return root, nil, nil + } + return nil, nil, errors.New("the node is not in the trie") + case *edgeNode: + parent = child + continue + case *binaryNode: + parent = child + continue + case *hashNode: + child, err = retrieveNode(n.Felt) + if err != nil { + return nil, nil, err + } + case *valueNode: + val = &n.Felt + } + // Link the parent and child + switch p := parent.(type) { + case *edgeNode: + p.child = child + case *binaryNode: + p.children[msb] = child + default: + panic(fmt.Sprintf("unknown parent node type: %T", p)) + } + + // We reached the leaf + if val != nil { + return root, val, nil + } + + parent = child + } +} + +// unsetInternal removes all internal node references (hashNode, embedded node). +// It should be called after a trie is constructed with two edge paths. Also +// the given boundary keys must be the ones used to construct the edge paths. +// +// It's the key step for range proof. All visited nodes should be marked dirty +// since the node content might be modified. +// +// Note we have the assumption here the given boundary keys are different +// and right is larger than left. +func unsetInternal(n node, left, right Path) (bool, error) { + // Step down to the fork point. 
There are two scenarios that can happen: + // - the fork point is an edgeNode: either the key of left proof or + // right proof doesn't match with the edge node's path + // - the fork point is a binaryNode: both two edge proofs are allowed + // to point to a non-existent key + var ( + pos uint8 + parent node + + // fork indicator for edge nodes + edgeForkLeft, edgeForkRight int + ) + +findFork: + for { + switch rn := n.(type) { + case *edgeNode: + rn.flags = newFlag() + + if left.Len()-pos < rn.path.Len() { + edgeForkLeft = new(Path).LSBs(&left, pos).Cmp(rn.path) + } else { + subKey := new(Path).Subset(&left, pos, pos+rn.path.Len()) + edgeForkLeft = subKey.Cmp(rn.path) + } + + if right.Len()-pos < rn.path.Len() { + edgeForkRight = new(Path).LSBs(&right, pos).Cmp(rn.path) + } else { + subKey := new(Path).Subset(&right, pos, pos+rn.path.Len()) + edgeForkRight = subKey.Cmp(rn.path) + } + + if edgeForkLeft != 0 || edgeForkRight != 0 { + break findFork + } + + parent = n + n = rn.child + pos += rn.path.Len() + case *binaryNode: + rn.flags = newFlag() + + leftnode, rightnode := rn.children[left.Bit(pos)], rn.children[right.Bit(pos)] + if leftnode == nil || rightnode == nil || leftnode != rightnode { + break findFork + } + parent = n + n = rn.children[left.Bit(pos)] + pos++ + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } + } + + switch rn := n.(type) { + case *edgeNode: + // There can have these five scenarios: + // - both proofs are less than the trie path => no valid range + // - both proofs are greater than the trie path => no valid range + // - left proof is less and right proof is greater => valid range, unset the shortnode entirely + // - left proof points to the shortnode, but right proof is greater + // - right proof points to the shortnode, but left proof is less + if edgeForkLeft == -1 && edgeForkRight == -1 { + return false, ErrEmptyRange + } + if edgeForkLeft == 1 && edgeForkRight == 1 { + return false, ErrEmptyRange + } + if edgeForkLeft != 0 && edgeForkRight != 0 { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } + parent.(*binaryNode).children[left.Bit(pos-1)] = nil + return false, nil + } + // Only one proof points to non-existent key + if edgeForkRight != 0 { + if _, ok := rn.child.(*valueNode); ok { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } + parent.(*binaryNode).children[left.Bit(pos-1)] = nil + return false, nil + } + return false, unset(rn, rn.child, new(Path).LSBs(&left, pos), rn.path.Len(), false) + } + if edgeForkLeft != 0 { + if _, ok := rn.child.(*valueNode); ok { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } + parent.(*binaryNode).children[right.Bit(pos-1)] = nil + return false, nil + } + return false, unset(rn, rn.child, new(Path).LSBs(&right, pos), rn.path.Len(), true) + } + return false, nil + case *binaryNode: + leftBit := left.Bit(pos) + rightBit := right.Bit(pos) + if leftBit == 0 && rightBit == 0 { + rn.children[1] = nil + } + if leftBit == 1 && rightBit == 1 { + rn.children[0] = nil + } + if err := unset(rn, rn.children[leftBit], new(Path).LSBs(&left, pos), 1, false); err != nil { + return false, err + } + if err := unset(rn, rn.children[rightBit], new(Path).LSBs(&right, pos), 1, true); err != nil { + return false, err + } + return false, nil + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + +// unset removes all internal node references either the left most or 
right most. +func unset(parent node, child node, key *Path, pos uint8, removeLeft bool) error { + switch cld := child.(type) { + case *binaryNode: + cld.flags = newFlag() // Mark dirty + if removeLeft { + // Remove left child if we're removing left side + if key.MSB() == 1 { + cld.children[0] = nil + } + } else { + // Remove right child if we're removing right side + if key.MSB() == 0 { + cld.children[1] = nil + } + } + return unset(cld, cld.children[key.Bit(pos)], key, pos+1, removeLeft) + + case *edgeNode: + keyPos := new(Path).LSBs(key, pos) + if !cld.pathMatches(keyPos) { + // Found fork point, non-existent branch + if removeLeft { + if cld.path.Cmp(keyPos) < 0 { + // Edge node path is in range, unset entire branch + parent.(*binaryNode).children[key.Bit(pos-1)] = nil + } + } else { + if cld.path.Cmp(keyPos) > 0 { + parent.(*binaryNode).children[key.Bit(pos-1)] = nil + } + } + return nil + } + if _, ok := cld.child.(*valueNode); ok { + parent.(*binaryNode).children[key.Bit(pos-1)] = nil + return nil + } + cld.flags = newFlag() + return unset(cld, cld.child, key, pos+cld.path.Len(), removeLeft) + + case nil, *hashNode, *valueNode: + // Child is nil, nothing to unset + return nil + default: + panic("it shouldn't happen") // hashNode, valueNode + } +} + +// hasRightElement checks if there is a right sibling for the given key in the trie. +// This function assumes that the entire path has been resolved. +func hasRightElement(node node, key Path) bool { + for node != nil { + switch n := node.(type) { + case *binaryNode: + for _, cn := range n.children { + if cn != nil { + return true + } + } + node = n.children[key.MSB()] + key.LSBs(&key, 1) + case *edgeNode: + if !n.pathMatches(&key) { + // There's a divergence in the path, check if the node path is greater than the key + // If so, that means that this node comes after the search key, which indicates that + // there are elements with larger values + return n.path.Cmp(&key) > 0 + } + node = n.child + case *valueNode: + return false // resolved the whole path + default: + panic(fmt.Sprintf("unknown node type: %T", n)) } } + return false } func get(rn node, key *Path, skipResolved bool) node { diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go index 929a75dad6..18804df4ab 100644 --- a/core/trie2/proof_test.go +++ b/core/trie2/proof_test.go @@ -1 +1,745 @@ package trie2 + +import ( + "math/rand" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/require" +) + +func TestProve(t *testing.T) { + t.Parallel() + + n := 1000 + tempTrie, records := nonRandomTrie(t, n) + + for _, record := range records { + root := tempTrie.Hash() + + proofSet := NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, record.key, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", record.key.String()) + } + + if !val.Equal(record.value) { + t.Fatalf("expected value %s, got %s", record.value.String(), val.String()) + } + } +} + +func TestProveNonExistent(t *testing.T) { + t.Parallel() + + n := 1000 + tempTrie, _ := nonRandomTrie(t, n) + + for i := 1; i < n+1; i++ { + root := tempTrie.Hash() + + keyFelt := new(felt.Felt).SetUint64(uint64(i + n)) + proofSet := NewProofNodeSet() + err := tempTrie.Prove(keyFelt, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, keyFelt, proofSet, crypto.Pedersen) + if err != nil { + 
t.Fatalf("failed for key %s", keyFelt.String()) + } + require.Equal(t, felt.Zero, val) + } +} + +func TestProveRandom(t *testing.T) { + t.Parallel() + tempTrie, records := randomTrie(t, 1000) + + for _, record := range records { + root := tempTrie.Hash() + + proofSet := NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, record.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + + if !val.Equal(record.value) { + t.Fatalf("expected value %s, got %s", record.value.String(), val.String()) + } + } +} + +func TestProveCustom(t *testing.T) { + t.Parallel() + + tests := []testTrie{ + { + name: "simple binary", + buildFn: buildSimpleTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(1), + expected: new(felt.Felt).SetUint64(3), + }, + }, + }, + { + name: "simple double binary", + buildFn: buildSimpleDoubleBinaryTrie, + testKeys: []testKey{ + { + name: "prove existing key 0", + key: new(felt.Felt).SetUint64(0), + expected: new(felt.Felt).SetUint64(2), + }, + { + name: "prove existing key 3", + key: new(felt.Felt).SetUint64(3), + expected: new(felt.Felt).SetUint64(5), + }, + { + name: "prove non-existent key 2", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(0), + }, + { + name: "prove non-existent key 123", + key: new(felt.Felt).SetUint64(123), + expected: new(felt.Felt).SetUint64(0), + }, + }, + }, + { + name: "simple binary root", + buildFn: buildSimpleBinaryRootTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(0), + expected: utils.HexToFelt(t, "0xcc"), + }, + }, + }, + { + name: "left-right edge", + buildFn: func(t *testing.T) (*Trie, []*keyValue) { + tr, err := NewEmptyPedersen() + require.NoError(t, err) + + records := []*keyValue{ + {key: utils.HexToFelt(t, "0xff"), value: utils.HexToFelt(t, "0xaa")}, + } + + for _, record := range records { + err = tr.Update(record.key, record.value) + require.NoError(t, err) + } + return tr, records + }, + testKeys: []testKey{ + { + name: "prove existing key", + key: utils.HexToFelt(t, "0xff"), + expected: utils.HexToFelt(t, "0xaa"), + }, + }, + }, + { + name: "three key trie", + buildFn: build3KeyTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(6), + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + tr, _ := test.buildFn(t) + + for _, tc := range test.testKeys { + t.Run(tc.name, func(t *testing.T) { + proofSet := NewProofNodeSet() + + root := tr.Hash() + err := tr.Prove(tc.key, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, tc.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + if !val.Equal(tc.expected) { + t.Fatalf("expected value %s, got %s", tc.expected.String(), val.String()) + } + }) + } + }) + } +} + +// TestRangeProof tests normal range proof with both edge proofs +func TestRangeProof(t *testing.T) { + t.Parallel() + + n := 500 + tr, records := randomTrie(t, n) + root := tr.Hash() + + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } + + _, err 
= VerifyRangeProof(&root, records[start].key, keys, values, proof) + require.NoError(t, err) + } +} + +// TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs +func TestRangeProofWithNonExistentProof(t *testing.T) { + t.Parallel() + + n := 500 + tr, records := randomTrie(t, n) + root := tr.Hash() + + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 + + first := decrementFelt(records[start].key) + if start != 0 && first.Equal(records[start-1].key) { + continue + } + + proof := NewProofNodeSet() + err := tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) + + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + _, err = VerifyRangeProof(&root, first, keys, values, proof) + require.NoError(t, err) + } +} + +// TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. +// One scenario is when there is a gap between the first element and the left edge proof. +func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { + t.Parallel() + + n := 500 + tr, records := randomTrie(t, n) + root := tr.Hash() + + start, end := 100, 200 + first := decrementFelt(records[start].key) + + proof := NewProofNodeSet() + err := tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) + + start = 105 // Gap created + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + _, err = VerifyRangeProof(&root, first, keys, values, proof) + require.Error(t, err) +} + +func TestRangeProofCustom(t *testing.T) { + tr, records := build4KeysTrieD(t) + root := tr.Hash() + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[0].key, records[2].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[0].key, []*felt.Felt{records[0].key, records[1].key, records[2].key, records[3].key}, []*felt.Felt{records[0].value, records[1].value, records[2].value, records[3].value}, proof) + require.NoError(t, err) +} + +func TestOneElementRangeProof(t *testing.T) { + n := 1000 + tr, records := randomTrie(t, n) + root := tr.Hash() + + t.Run("both edge proofs with the same key", func(t *testing.T) { + start := 100 + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[start].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[start].key, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) + require.NoError(t, err) + }) + + t.Run("left non-existent edge proof", func(t *testing.T) { + start := 100 + proof := NewProofNodeSet() + err := tr.GetRangeProof(decrementFelt(records[start].key), records[start].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, decrementFelt(records[start].key), []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) + require.NoError(t, err) + }) + + t.Run("right non-existent edge proof", func(t *testing.T) { + end := 100 + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[end].key, incrementFelt(records[end].key), proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[end].key, []*felt.Felt{records[end].key}, []*felt.Felt{records[end].value}, proof) + require.NoError(t, err) + }) + + t.Run("both 
non-existent edge proofs", func(t *testing.T) { + start := 100 + first, last := decrementFelt(records[start].key), incrementFelt(records[start].key) + proof := NewProofNodeSet() + err := tr.GetRangeProof(first, last, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, first, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) + require.NoError(t, err) + }) + + t.Run("1 key trie", func(t *testing.T) { + tr, records := build1KeyTrie(t) + root := tr.Hash() + + proof := NewProofNodeSet() + err := tr.GetRangeProof(&felt.Zero, records[0].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[0].key, []*felt.Felt{records[0].key}, []*felt.Felt{records[0].value}, proof) + require.NoError(t, err) + }) +} + +// TestAllElementsRangeProof tests the range proof with all elements and nil proof. +func TestAllElementsRangeProof(t *testing.T) { + t.Parallel() + + n := 1000 + tr, records := randomTrie(t, n) + root := tr.Hash() + + keys := make([]*felt.Felt, n) + values := make([]*felt.Felt, n) + for i, record := range records { + keys[i] = record.key + values[i] = record.value + } + + _, err := VerifyRangeProof(&root, nil, keys, values, nil) + require.NoError(t, err) + + // Should also work with proof + proof := NewProofNodeSet() + err = tr.GetRangeProof(records[0].key, records[n-1].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, keys[0], keys, values, proof) + require.NoError(t, err) +} + +// TestSingleSideRangeProof tests the range proof starting with zero. +func TestSingleSideRangeProof(t *testing.T) { + t.Parallel() + + tr, records := randomTrie(t, 1000) + root := tr.Hash() + + for i := 0; i < len(records); i += 100 { + proof := NewProofNodeSet() + err := tr.GetRangeProof(&felt.Zero, records[i].key, proof) + require.NoError(t, err) + + keys := make([]*felt.Felt, i+1) + values := make([]*felt.Felt, i+1) + for j := 0; j < i+1; j++ { + keys[j] = records[j].key + values[j] = records[j].value + } + + _, err = VerifyRangeProof(&root, &felt.Zero, keys, values, proof) + require.NoError(t, err) + } +} + +func TestGappedRangeProof(t *testing.T) { + t.Parallel() + // t.Skip("gapped keys will sometimes succeed, the current proof format is not able to handle this") + + tr, records := nonRandomTrie(t, 5) + root := tr.Hash() + + first, last := 1, 4 + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[first].key, records[last].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := first; i <= last; i++ { + if i == (first+last)/2 { + continue + } + + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } + + _, err = VerifyRangeProof(&root, records[first].key, keys, values, proof) + require.Error(t, err) +} + +func TestEmptyRangeProof(t *testing.T) { + t.Parallel() + + tr, records := randomTrie(t, 1000) + root := tr.Hash() + + cases := []struct { + pos int + err bool + }{ + {len(records) - 1, false}, + {500, true}, + } + + for _, c := range cases { + proof := NewProofNodeSet() + first := incrementFelt(records[c.pos].key) + err := tr.GetRangeProof(first, first, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, first, nil, nil, proof) + if c.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } +} + +func TestHasRightElement(t *testing.T) { + t.Parallel() + + tr, records := randomTrie(t, 500) + root := tr.Hash() + + cases := []struct { + start int + end int + hasMore bool + }{ + {-1, 1, true}, // 
single element with non-existent left proof + {0, 1, true}, // single element with existent left proof + {0, 100, true}, // start to middle + {50, 100, true}, // middle only + {50, len(records), false}, // middle to end + {len(records) - 1, len(records), false}, // Single last element with two existent proofs(point to same key) + {0, len(records), false}, // The whole set with existent left proof + {-1, len(records), false}, // The whole set with non-existent left proof + } + + for _, c := range cases { + var ( + first *felt.Felt + start = c.start + end = c.end + proof = NewProofNodeSet() + ) + if start == -1 { + first = &felt.Zero + start = 0 + } else { + first = records[start].key + } + + err := tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } + + hasMore, err := VerifyRangeProof(&root, first, keys, values, proof) + require.NoError(t, err) + require.Equal(t, c.hasMore, hasMore) + } +} + +// TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. +func TestBadRangeProof(t *testing.T) { + t.Parallel() + + tr, records := randomTrie(t, 1000) + root := tr.Hash() + + for i := 0; i < 100; i++ { + start := rand.Intn(len(records)) + end := rand.Intn(len(records)-start) + start + 1 + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for j := start; j < end; j++ { + keys = append(keys, records[j].key) + values = append(values, records[j].value) + } + + first := keys[0] + testCase := rand.Intn(5) + + index := rand.Intn(end - start) + switch testCase { + case 0: // modified key + keys[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 1: // modified value + values[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 2: // out of order + index2 := rand.Intn(end - start) + if index2 == index { + continue + } + keys[index], keys[index2] = keys[index2], keys[index] + values[index], values[index2] = values[index2], values[index] + case 3: // set random key to empty + keys[index] = &felt.Zero + case 4: // set random value to empty + values[index] = &felt.Zero + // TODO(weiihann): gapped proof will fail sometimes + // case 5: // gapped + // if end-start < 100 || index == 0 || index == end-start-1 { + // continue + // } + // keys = append(keys[:index], keys[index+1:]...) + // values = append(values[:index], values[index+1:]...) 
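+ // Left disabled for now: as noted in TestGappedRangeProof above, the current proof
+ // format sometimes accepts a range with a key missing strictly inside it, so this
+ // case cannot reliably assert a verification failure here yet.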
+ } + _, err = VerifyRangeProof(&root, first, keys, values, proof) + if err == nil { + t.Fatalf("expected error for test case %d, index %d, start %d, end %d", testCase, index, start, end) + } + } +} + +func BenchmarkProve(b *testing.B) { + tr, records := randomTrie(b, 1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + proof := NewProofNodeSet() + key := records[i%len(records)].key + if err := tr.Prove(key, proof); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkVerifyProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root := tr.Hash() + + var proofs []*ProofNodeSet + for _, record := range records { + proof := NewProofNodeSet() + if err := tr.Prove(record.key, proof); err != nil { + b.Fatal(err) + } + proofs = append(proofs, proof) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + index := i % len(records) + if _, err := VerifyProof(&root, records[index].key, proofs[index], crypto.Pedersen); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkVerifyRangeProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root := tr.Hash() + + start := 2 + end := start + 500 + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(b, err) + + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := VerifyRangeProof(&root, keys[0], keys, values, proof) + require.NoError(b, err) + } +} + +func buildTrie(t *testing.T, records []*keyValue) *Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } + + tempTrie, err := NewEmptyPedersen() + require.NoError(t, err) + + for _, record := range records { + err = tempTrie.Update(record.key, record.value) + require.NoError(t, err) + } + + return tempTrie +} + +func build1KeyTrie(t *testing.T) (*Trie, []*keyValue) { + return nonRandomTrie(t, 1) +} + +func buildSimpleTrie(t *testing.T) (*Trie, []*keyValue) { + // (250, 0, x1) edge + // | + // (0,0,x1) binary + // / \ + // (2) (3) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + } + + return buildTrie(t, records), records +} + +func buildSimpleBinaryRootTrie(t *testing.T) (*Trie, []*keyValue) { + // PF + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + + // JUNO + // (0, 0, x) + // / \ + // (251, 0, cc) (251, 11111.., dd) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: utils.HexToFelt(t, "0xcc")}, + {key: utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), value: utils.HexToFelt(t, "0xdd")}, + } + return buildTrie(t, records), records +} + +//nolint:dupl +func buildSimpleDoubleBinaryTrie(t *testing.T) (*Trie, []*keyValue) { + // (249,0,x3) // Edge + // | + // (0, 0, x3) // Binary + // / \ + // (0,0,x1) // B (1, 1, 5) // Edge leaf + // / \ | + // (2) (3) (5) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + {key: new(felt.Felt).SetUint64(3), value: new(felt.Felt).SetUint64(5)}, + } + return buildTrie(t, records), records +} + +//nolint:dupl +func build3KeyTrie(t *testing.T) (*Trie, []*keyValue) { + // Starknet + // -------- + // + // Edge + // | + // Binary with 
len 249 parent + // / \ + // Binary (250) Edge with len 250 + // / \ / + // 0x4 0x5 0x6 child + + // Juno + // ---- + // + // Node (path 249) + // / \ + // Node (binary) \ + // / \ / + // 0x4 0x5 0x6 + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(4)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(2), value: new(felt.Felt).SetUint64(6)}, + } + + return buildTrie(t, records), records +} + +func decrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Sub(f, new(felt.Felt).SetUint64(1)) +} + +func incrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Add(f, new(felt.Felt).SetUint64(1)) +} + +type testKey struct { + name string + key *felt.Felt + expected *felt.Felt +} + +type testTrie struct { + name string + buildFn func(*testing.T) (*Trie, []*keyValue) + testKeys []testKey +} diff --git a/core/trie2/tracer.go b/core/trie2/tracer.go index 81b253718d..0d0659130a 100644 --- a/core/trie2/tracer.go +++ b/core/trie2/tracer.go @@ -4,19 +4,19 @@ import ( "maps" ) -type tracer struct { +type nodeTracer struct { inserts map[Path]struct{} deletes map[Path]struct{} } -func newTracer() *tracer { - return &tracer{ +func newTracer() *nodeTracer { + return &nodeTracer{ inserts: make(map[Path]struct{}), deletes: make(map[Path]struct{}), } } -func (t *tracer) onInsert(key *Path) { +func (t *nodeTracer) onInsert(key *Path) { k := *key if _, present := t.deletes[k]; present { return @@ -24,7 +24,7 @@ func (t *tracer) onInsert(key *Path) { t.inserts[k] = struct{}{} } -func (t *tracer) onDelete(key *Path) { +func (t *nodeTracer) onDelete(key *Path) { k := *key if _, present := t.inserts[k]; present { return @@ -32,19 +32,19 @@ func (t *tracer) onDelete(key *Path) { t.deletes[k] = struct{}{} } -func (t *tracer) reset() { +func (t *nodeTracer) reset() { t.inserts = make(map[Path]struct{}) t.deletes = make(map[Path]struct{}) } -func (t *tracer) copy() *tracer { - return &tracer{ +func (t *nodeTracer) copy() *nodeTracer { + return &nodeTracer{ inserts: maps.Clone(t.inserts), deletes: maps.Clone(t.deletes), } } -func (t *tracer) deletedNodes() []Path { +func (t *nodeTracer) deletedNodes() []Path { keys := make([]Path, 0, len(t.deletes)) for k := range t.deletes { keys = append(keys, k) diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 2085e0774a..60744add6c 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -38,7 +38,7 @@ type Trie struct { committed bool // Maintains the records of trie changes, ensuring all nodes are modified or garbage collected properly - tracer *tracer + nodeTracer *nodeTracer // Tracks the number of leaves inserted since the last hashing operation pendingHashes int @@ -50,11 +50,11 @@ type Trie struct { func New(id *ID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, error) { database := triedb.New(txn, id.Bucket()) tr := &Trie{ - owner: id.Owner, - height: height, - hashFn: hashFn, - db: database, - tracer: newTracer(), + owner: id.Owner, + height: height, + hashFn: hashFn, + db: database, + nodeTracer: newTracer(), } if id.Root != emptyRoot { @@ -67,6 +67,16 @@ func New(id *ID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, return tr, nil } +// Creates an empty trie, only used for temporary trie construction +func NewEmpty(height uint8, hashFn crypto.HashFn) *Trie { + return &Trie{ + height: height, + hashFn: hashFn, + root: nil, + nodeTracer: newTracer(), + } +} + // Modifies or inserts a key-value pair in the trie. 
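// A minimal usage sketch (illustrative only), building a temporary in-memory trie with the
// NewEmpty constructor added above and inserting a single key-value pair:
//
//	tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen)
//	key := new(felt.Felt).SetUint64(1)
//	value := new(felt.Felt).SetUint64(2)
//	if err := tr.Update(key, value); err != nil {
//		// handle the error
//	}
//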
// If value is zero, the key is deleted from the trie. func (t *Trie) Update(key, value *felt.Felt) error { @@ -152,7 +162,7 @@ func (t *Trie) Commit() (felt.Felt, error) { } nodes := trienode.NewNodeSet(t.owner) - for _, Path := range t.tracer.deletedNodes() { + for _, Path := range t.nodeTracer.deletedNodes() { nodes.Add(Path, trienode.NewDeleted()) } @@ -180,7 +190,7 @@ func (t *Trie) Copy() *Trie { hashFn: t.hashFn, committed: t.committed, db: t.db, - tracer: t.tracer.copy(), + nodeTracer: t.nodeTracer.copy(), pendingHashes: t.pendingHashes, pendingUpdates: t.pendingUpdates, } @@ -291,7 +301,7 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) return true, branch, nil } matchPrefix := new(Path).MSBs(key, match.Len()) - t.tracer.onInsert(new(Path).Append(prefix, matchPrefix)) + t.nodeTracer.onInsert(new(Path).Append(prefix, matchPrefix)) // Otherwise, create a new edge node with the Path being the common Path and the branch as the child return true, &edgeNode{path: matchPrefix, child: branch, flags: newFlag()}, nil @@ -311,7 +321,7 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) n.children[bit] = newNode return true, n, nil case nil: - t.tracer.onInsert(prefix) + t.nodeTracer.onInsert(prefix) // We reach the end of the key, return the value node if key.IsEmpty() { return true, value, nil @@ -343,7 +353,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { } // If the whole key matches, remove the entire edge node if match.Len() == key.Len() { - t.tracer.onDelete(prefix) + t.nodeTracer.onDelete(prefix) return true, nil, nil } @@ -356,7 +366,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { } switch child := child.(type) { case *edgeNode: - t.tracer.onDelete(new(Path).Append(prefix, n.path)) + t.nodeTracer.onDelete(new(Path).Append(prefix, n.path)) return true, &edgeNode{path: new(Path).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil default: return true, &edgeNode{path: new(Path).Set(n.path), child: child, flags: newFlag()}, nil @@ -381,7 +391,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { other := bit ^ 1 bitPrefix := new(Path).SetBit(other) if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child Path - t.tracer.onDelete(new(Path).Append(prefix, bitPrefix)) + t.nodeTracer.onDelete(new(Path).Append(prefix, bitPrefix)) return true, &edgeNode{ path: new(Path).Append(bitPrefix, cn.path), child: cn.child, diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 360ece6c00..8ef0fd41f5 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -2,6 +2,7 @@ package trie2 import ( "math/rand" + "sort" "testing" "github.com/NethermindEth/juno/core/felt" @@ -221,6 +222,11 @@ func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { require.NoError(t, err) } + // Sort records by key + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 + }) + return tr, records } @@ -232,10 +238,10 @@ func build4KeysTrieD(t *testing.T) (*Trie, []*keyValue) { {key: new(felt.Felt).SetUint64(7), value: new(felt.Felt).SetUint64(7)}, } - return buildTrie(t, records), records + return buildTestTrie(t, records), records } -func buildTrie(t *testing.T, records []*keyValue) *Trie { +func buildTestTrie(t *testing.T, records []*keyValue) *Trie { if len(records) == 0 { t.Fatal("records must have at least one element") } diff --git 
a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index 3b404cb37c..3af7a774b0 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -104,6 +104,7 @@ func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { } // Returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// Think of this method as array[n:] // For example: // // x = 11001011 (len=8) @@ -159,6 +160,7 @@ func (b *BitArray) EqualMSBs(x *BitArray) bool { // Sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). // If n >= x.len, the bit array is an exact copy of x. +// Think of this method as array[0:n] // For example: // // x = 11001011 (len=8) @@ -324,6 +326,39 @@ func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { return b.Append(b, new(BitArray).SetBit(bit)) } +// Sets the bit array to a subset of x from startPos (inclusive) to endPos (exclusive), +// where position 0 is the MSB. If startPos >= endPos or if startPos >= x.len, +// returns an empty BitArray. +// Think of this method as array[start:end] +// For example: +// +// x = 001011011 (len=9) +// Subset(x, 2, 5) = 101 (len=3) +func (b *BitArray) Subset(x *BitArray, startPos, endPos uint8) *BitArray { + // Check for invalid inputs + if startPos >= endPos || startPos >= x.len { + return b.clear() + } + + // Clamp endPos to x.len if it exceeds it + if endPos > x.len { + endPos = x.len + } + + length := endPos - startPos + + // First, trim off the MSBs that are not part of the subset + b.LSBs(x, startPos) + + // Then, we create a mask of ones and appends zeros to the end to match the length + mask := new(BitArray).Ones(length) + zeros := &BitArray{len: b.len - length} + mask.Append(mask, zeros) + + // Apply the mask to the bit array and then only take the first `length` bits + return b.And(b, mask).MSBs(b, length) +} + // Sets the bit array to x | y and returns the bit array. func (b *BitArray) Or(x, y *BitArray) *BitArray { b.words[0] = x.words[0] | y.words[0] @@ -334,6 +369,15 @@ func (b *BitArray) Or(x, y *BitArray) *BitArray { return b } +func (b *BitArray) And(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] & y.words[0] + b.words[1] = x.words[1] & y.words[1] + b.words[2] = x.words[2] & y.words[2] + b.words[3] = x.words[3] & y.words[3] + b.len = x.len + return b +} + // Sets the bit array to x ^ y and returns the bit array. func (b *BitArray) Xor(x, y *BitArray) *BitArray { b.words[0] = x.words[0] ^ y.words[0] @@ -741,6 +785,17 @@ func (b *BitArray) truncateToLength() { } } +// Sets the bit array to a sequence of ones of the specified length. +func (b *BitArray) Ones(length uint8) *BitArray { + b.len = length + b.words[0] = maxUint64 + b.words[1] = maxUint64 + b.words[2] = maxUint64 + b.words[3] = maxUint64 + b.truncateToLength() + return b +} + // Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. // The bit position is counted from the least significant bit, starting at 0. 
// For example: From 245a389f7d85902862d30ca5b697ef163417fc5b Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 10:12:53 +0800 Subject: [PATCH 25/44] still failing --- core/trie/proof_test.go | 40 ++------------------------------ core/trie2/hasher.go | 15 ++++++++++++ core/trie2/node.go | 17 +++++++++++--- core/trie2/proof.go | 39 +++++++++++++++++++++---------- core/trie2/proof_test.go | 22 ++++++++++++++++++ core/trie2/trieutils/bitarray.go | 2 +- 6 files changed, 81 insertions(+), 54 deletions(-) diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index dcbc9dc43e..dc441036cc 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -14,8 +14,6 @@ import ( ) func TestProve(t *testing.T) { - t.Parallel() - n := 1000 tempTrie, records := nonRandomTrie(t, n) @@ -36,8 +34,6 @@ func TestProve(t *testing.T) { } func TestProveNonExistent(t *testing.T) { - t.Parallel() - n := 1000 tempTrie, _ := nonRandomTrie(t, n) @@ -60,7 +56,6 @@ func TestProveNonExistent(t *testing.T) { } func TestProveRandom(t *testing.T) { - t.Parallel() tempTrie, records := randomTrie(t, 1000) for _, record := range records { @@ -78,8 +73,6 @@ func TestProveRandom(t *testing.T) { } func TestProveCustom(t *testing.T) { - t.Parallel() - tests := []testTrie{ { name: "simple binary", @@ -173,8 +166,6 @@ func TestProveCustom(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - t.Parallel() - tr, _ := test.buildFn(t) for _, tc := range test.testKeys { @@ -197,8 +188,6 @@ func TestProveCustom(t *testing.T) { // TestRangeProof tests normal range proof with both edge proofs func TestRangeProof(t *testing.T) { - t.Parallel() - n := 500 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -226,8 +215,6 @@ func TestRangeProof(t *testing.T) { // TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs func TestRangeProofWithNonExistentProof(t *testing.T) { - t.Parallel() - n := 500 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -260,9 +247,8 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { // TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. // One scenario is when there is a gap between the first element and the left edge proof. 
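// Concretely: the left edge proof is generated for a starting key just below the first record,
// but the keys supplied to the verifier begin a few records later, leaving an unproven gap on
// the left edge, so verification is expected to fail.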
+// TODO(weiihann): this is failing func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { - t.Parallel() - n := 500 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -289,16 +275,12 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { } func TestOneElementRangeProof(t *testing.T) { - t.Parallel() - n := 1000 tr, records := randomTrie(t, n) root, err := tr.Root() require.NoError(t, err) t.Run("both edge proofs with the same key", func(t *testing.T) { - t.Parallel() - start := 100 proof := trie.NewProofNodeSet() err = tr.GetRangeProof(records[start].key, records[start].key, proof) @@ -309,8 +291,6 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("left non-existent edge proof", func(t *testing.T) { - t.Parallel() - start := 100 proof := trie.NewProofNodeSet() err = tr.GetRangeProof(decrementFelt(records[start].key), records[start].key, proof) @@ -321,8 +301,6 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("right non-existent edge proof", func(t *testing.T) { - t.Parallel() - end := 100 proof := trie.NewProofNodeSet() err = tr.GetRangeProof(records[end].key, incrementFelt(records[end].key), proof) @@ -333,8 +311,6 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("both non-existent edge proofs", func(t *testing.T) { - t.Parallel() - start := 100 first, last := decrementFelt(records[start].key), incrementFelt(records[start].key) proof := trie.NewProofNodeSet() @@ -346,8 +322,6 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("1 key trie", func(t *testing.T) { - t.Parallel() - tr, records := build1KeyTrie(t) root, err := tr.Root() require.NoError(t, err) @@ -363,8 +337,6 @@ func TestOneElementRangeProof(t *testing.T) { // TestAllElementsRangeProof tests the range proof with all elements and nil proof. func TestAllElementsRangeProof(t *testing.T) { - t.Parallel() - n := 1000 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -391,8 +363,6 @@ func TestAllElementsRangeProof(t *testing.T) { // TestSingleSideRangeProof tests the range proof starting with zero. func TestSingleSideRangeProof(t *testing.T) { - t.Parallel() - tr, records := randomTrie(t, 1000) root, err := tr.Root() require.NoError(t, err) @@ -414,8 +384,8 @@ func TestSingleSideRangeProof(t *testing.T) { } } +// TODO(weiihann): this is failing func TestGappedRangeProof(t *testing.T) { - t.Parallel() t.Skip("gapped keys will sometimes succeed, the current proof format is not able to handle this") tr, records := nonRandomTrie(t, 5) @@ -443,8 +413,6 @@ func TestGappedRangeProof(t *testing.T) { } func TestEmptyRangeProof(t *testing.T) { - t.Parallel() - tr, records := randomTrie(t, 1000) root, err := tr.Root() require.NoError(t, err) @@ -473,8 +441,6 @@ func TestEmptyRangeProof(t *testing.T) { } func TestHasRightElement(t *testing.T) { - t.Parallel() - tr, records := randomTrie(t, 500) root, err := tr.Root() require.NoError(t, err) @@ -526,8 +492,6 @@ func TestHasRightElement(t *testing.T) { // TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. func TestBadRangeProof(t *testing.T) { - t.Parallel() - tr, records := randomTrie(t, 1000) root, err := tr.Root() require.NoError(t, err) diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go index 12c9cef284..401e9c9c07 100644 --- a/core/trie2/hasher.go +++ b/core/trie2/hasher.go @@ -83,3 +83,18 @@ func (h *hasher) hashBinaryChildren(n *binaryNode) (collapsed, cached *binaryNod return collapsed, cached } + +// Construct trie proofs and returns the collapsed node (i.e. 
nodes with hash children) +// and the hashed node. +func (h *hasher) proofHash(original node) (collapsed, hashed node) { + switch n := original.(type) { + case *edgeNode: + en, _ := h.hashEdgeChild(n) + return en, &hashNode{Felt: *en.hash(h.hashFn)} + case *binaryNode: + bn, _ := h.hashBinaryChildren(n) + return bn, &hashNode{Felt: *bn.hash(h.hashFn)} + default: + return n, n + } +} diff --git a/core/trie2/node.go b/core/trie2/node.go index 1333a0ba9b..bd1884b092 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -65,15 +65,26 @@ func (n *hashNode) cache() (*hashNode, bool) { return nil, true } func (n *valueNode) cache() (*hashNode, bool) { return nil, true } func (n *binaryNode) String() string { + var left, right string + if n.children[0] != nil { + left = n.children[0].String() + } + if n.children[1] != nil { + right = n.children[1].String() + } return fmt.Sprintf("Binary[\n left: %s\n right: %s\n]", - indent(n.children[0].String()), - indent(n.children[1].String())) + indent(left), + indent(right)) } func (n *edgeNode) String() string { + var child string + if n.child != nil { + child = n.child.String() + } return fmt.Sprintf("Edge{\n path: %s\n child: %s\n}", n.path.String(), - indent(n.child.String())) + indent(child)) } func (n hashNode) String() string { diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 9b5132a5d2..d9792c1d23 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -63,10 +63,16 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { // TODO: ideally Hash() should be called before Prove() so that the hashes are cached // There should be a better way to do this - h := newHasher(t.hashFn, false) - for _, n := range nodes { - hashed, cached := h.hash(n) // subsequent nodes are cached - proof.Put(hashed.(*hashNode).Felt, cached) + hasher := newHasher(t.hashFn, false) + for i, n := range nodes { + var hn node + n, hn = hasher.proofHash(n) + if hash, ok := hn.(*hashNode); ok || i == 0 { + if !ok { + hash = &hashNode{Felt: *n.hash(hasher.hashFn)} + } + proof.Put(hash.Felt, n) + } } return nil @@ -244,22 +250,31 @@ func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, pro return false, err } - // TODO: unset internal - // empty, err := unsetInternal(root, firstKey, lastKey) - // if err != nil { - // return false, err - // } + fmt.Println("Before unsetInternal") + fmt.Println(root.String()) + + empty, err := unsetInternal(root, firstKey, lastKey) + if err != nil { + return false, err + } + + fmt.Println("After unsetInternal") + fmt.Println(root.String()) tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen) - // if !empty { - // tr.root = root - // } + if !empty { + tr.root = root + } + for i, key := range keys { if err := tr.Update(key, values[i]); err != nil { return false, err } } + fmt.Println("After update") + fmt.Println(tr.root.String()) + newRoot := tr.Hash() // Verify that the recomputed root hash matches the provided root hash diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go index 18804df4ab..65450e4def 100644 --- a/core/trie2/proof_test.go +++ b/core/trie2/proof_test.go @@ -218,6 +218,28 @@ func TestRangeProof(t *testing.T) { } } +func TestRangeProofNonRandom(t *testing.T) { + tr, records := nonRandomTrie(t, 6) + root := tr.Hash() + + t.Log("Original trie") + t.Log(tr.root.String()) + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[0].key, records[3].key, proof) + require.NoError(t, err) + + for i := 0; i < 4; i++ { + t.Logf("key %s: value %s", records[i].key.String(), 
records[i].value.String()) + } + + keys := []*felt.Felt{records[0].key, records[1].key, records[2].key, records[3].key} + values := []*felt.Felt{records[0].value, records[1].value, records[2].value, records[3].value} + + _, err = VerifyRangeProof(&root, records[0].key, keys, values, proof) + require.NoError(t, err) +} + // TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs func TestRangeProofWithNonExistentProof(t *testing.T) { t.Parallel() diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index 3af7a774b0..27f2f9ba7e 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -323,7 +323,7 @@ func (b *BitArray) Append(x, y *BitArray) *BitArray { // Sets the bit array to the concatenation of x and a single bit. func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { - return b.Append(b, new(BitArray).SetBit(bit)) + return b.Append(x, new(BitArray).SetBit(bit)) } // Sets the bit array to a subset of x from startPos (inclusive) to endPos (exclusive), From 6db9aef182981ec8da853afe5bf1f13d3d136dd3 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 10:28:29 +0800 Subject: [PATCH 26/44] pass base test but got nil dereference bug --- core/trie2/proof.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/trie2/proof.go b/core/trie2/proof.go index d9792c1d23..8c1f60d03b 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -491,15 +491,16 @@ findFork: func unset(parent node, child node, key *Path, pos uint8, removeLeft bool) error { switch cld := child.(type) { case *binaryNode: + keyBit := key.Bit(pos) cld.flags = newFlag() // Mark dirty if removeLeft { // Remove left child if we're removing left side - if key.MSB() == 1 { + if keyBit == 1 { cld.children[0] = nil } } else { // Remove right child if we're removing right side - if key.MSB() == 0 { + if keyBit == 0 { cld.children[1] = nil } } @@ -507,22 +508,23 @@ func unset(parent node, child node, key *Path, pos uint8, removeLeft bool) error case *edgeNode: keyPos := new(Path).LSBs(key, pos) + keyBit := key.Bit(pos - 1) if !cld.pathMatches(keyPos) { // Found fork point, non-existent branch if removeLeft { if cld.path.Cmp(keyPos) < 0 { // Edge node path is in range, unset entire branch - parent.(*binaryNode).children[key.Bit(pos-1)] = nil + parent.(*binaryNode).children[keyBit] = nil } } else { if cld.path.Cmp(keyPos) > 0 { - parent.(*binaryNode).children[key.Bit(pos-1)] = nil + parent.(*binaryNode).children[keyBit] = nil } } return nil } if _, ok := cld.child.(*valueNode); ok { - parent.(*binaryNode).children[key.Bit(pos-1)] = nil + parent.(*binaryNode).children[keyBit] = nil return nil } cld.flags = newFlag() From 35ac499cc4da6a31a7145cf0437cde50de38d127 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 10:45:47 +0800 Subject: [PATCH 27/44] fix invalid non existent proof --- core/trie2/hasher.go | 25 +++++++++++++++++++++---- core/trie2/node.go | 3 +++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go index 401e9c9c07..0fe11c6590 100644 --- a/core/trie2/hasher.go +++ b/core/trie2/hasher.go @@ -67,18 +67,35 @@ func (h *hasher) hashBinaryChildren(n *binaryNode) (collapsed, cached *binaryNod go func() { defer wg.Done() - collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + if n.children[0] != nil { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + } else { + collapsed.children[0], 
cached.children[0] = nilValueNode, nilValueNode + } }() go func() { defer wg.Done() - collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + if n.children[1] != nil { + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } else { + collapsed.children[1], cached.children[1] = nilValueNode, nilValueNode + } }() wg.Wait() } else { - collapsed.children[0], cached.children[0] = h.hash(n.children[0]) - collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + if n.children[0] != nil { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + } else { + collapsed.children[0], cached.children[0] = nilValueNode, nilValueNode + } + + if n.children[1] != nil { + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } else { + collapsed.children[1], cached.children[1] = nilValueNode, nilValueNode + } } return collapsed, cached diff --git a/core/trie2/node.go b/core/trie2/node.go index bd1884b092..f5fc456fe9 100644 --- a/core/trie2/node.go +++ b/core/trie2/node.go @@ -37,6 +37,9 @@ type ( valueNode struct{ felt.Felt } ) +// nilValueNode is used when collapsing internal trie nodes for hashing, since unset children need to be hashed correctly +var nilValueNode = &valueNode{felt.Felt{}} + type nodeFlag struct { hash *hashNode dirty bool From 0d31b31af1d04c1bfadb8d23e138046777a6924b Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 19:22:53 +0800 Subject: [PATCH 28/44] fix hasRightElement --- core/trie2/proof.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 8c1f60d03b..11b6ab7382 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -544,12 +544,12 @@ func hasRightElement(node node, key Path) bool { for node != nil { switch n := node.(type) { case *binaryNode: - for _, cn := range n.children { - if cn != nil { - return true - } + bit := key.MSB() + if bit == 0 && n.children[1] != nil { + // right sibling exists + return true } - node = n.children[key.MSB()] + node = n.children[bit] key.LSBs(&key, 1) case *edgeNode: if !n.pathMatches(&key) { @@ -559,6 +559,7 @@ func hasRightElement(node node, key Path) bool { return n.path.Cmp(&key) > 0 } node = n.child + key.LSBs(&key, n.path.Len()) case *valueNode: return false // resolved the whole path default: From 68e949c7f67edde1db5082b9d4137113ba13c84a Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 21:30:25 +0800 Subject: [PATCH 29/44] create TrieDB interface --- core/trie2/trie.go | 7 ++++--- core/trie2/triedb/database.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 60744add6c..4e233107b4 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -32,7 +32,7 @@ type Trie struct { hashFn crypto.HashFn // The underlying database to store and retrieve trie nodes - db *triedb.Database + db triedb.TrieDB // Check if the trie has been committed. Trie is unusable once committed. 
committed bool @@ -74,6 +74,7 @@ func NewEmpty(height uint8, hashFn crypto.HashFn) *Trie { hashFn: hashFn, root: nil, nodeTracer: newTracer(), + db: triedb.EmptyDatabase{}, } } @@ -280,9 +281,9 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) // Otherwise branch out at the bit position where they differ branch := &binaryNode{flags: newFlag()} var err error - PathPrefix := new(Path).MSBs(n.path, match.Len()+1) + pathPrefix := new(Path).MSBs(n.path, match.Len()+1) _, branch.children[n.path.Bit(match.Len())], err = t.insert( - nil, new(Path).Append(prefix, PathPrefix), new(Path).LSBs(n.path, match.Len()+1), n.child, + nil, new(Path).Append(prefix, pathPrefix), new(Path).LSBs(n.path, match.Len()+1), n.child, ) if err != nil { return false, n, err diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go index ac0b908e1b..979c0ab840 100644 --- a/core/trie2/triedb/database.go +++ b/core/trie2/triedb/database.go @@ -2,6 +2,7 @@ package triedb import ( "bytes" + "errors" "sync" "github.com/NethermindEth/juno/core/felt" @@ -9,12 +10,25 @@ import ( "github.com/NethermindEth/juno/db" ) +var ErrCallEmptyDatabase = errors.New("call to empty database") + var dbBufferPool = sync.Pool{ New: func() any { return new(bytes.Buffer) }, } +var ( + _ TrieDB = (*Database)(nil) + _ TrieDB = (*EmptyDatabase)(nil) +) + +type TrieDB interface { + Get(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) (int, error) + Put(owner felt.Felt, path trieutils.BitArray, blob []byte) error + Delete(owner felt.Felt, path trieutils.BitArray) error +} + type Database struct { txn db.Transaction prefix db.Bucket @@ -98,3 +112,17 @@ func (d *Database) dbKey(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitA return nil } + +type EmptyDatabase struct{} + +func (EmptyDatabase) Get(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) (int, error) { + return 0, ErrCallEmptyDatabase +} + +func (EmptyDatabase) Put(owner felt.Felt, path trieutils.BitArray, blob []byte) error { + return ErrCallEmptyDatabase +} + +func (EmptyDatabase) Delete(owner felt.Felt, path trieutils.BitArray) error { + return ErrCallEmptyDatabase +} From 09ea3be987518197aed89a3404797a561917feb2 Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 21:30:46 +0800 Subject: [PATCH 30/44] range proof test cases all pass --- core/trie2/proof.go | 16 ++++------- core/trie2/proof_test.go | 49 +++++++------------------------- core/trie2/trieutils/bitarray.go | 14 +++++++++ 3 files changed, 29 insertions(+), 50 deletions(-) diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 11b6ab7382..432b410363 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -250,17 +250,11 @@ func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, pro return false, err } - fmt.Println("Before unsetInternal") - fmt.Println(root.String()) - empty, err := unsetInternal(root, firstKey, lastKey) if err != nil { return false, err } - fmt.Println("After unsetInternal") - fmt.Println(root.String()) - tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen) if !empty { tr.root = root @@ -272,9 +266,6 @@ func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, pro } } - fmt.Println("After update") - fmt.Println(tr.root.String()) - newRoot := tr.Hash() // Verify that the recomputed root hash matches the provided root hash @@ -510,14 +501,17 @@ func unset(parent node, child node, key *Path, pos uint8, removeLeft bool) error keyPos := new(Path).LSBs(key, pos) keyBit := 
key.Bit(pos - 1) if !cld.pathMatches(keyPos) { + // We append zeros to the path to match the length of the remaining key + // The key length is guaranteed to be >= path length + edgePath := new(Path).AppendZeros(cld.path, keyPos.Len()-cld.path.Len()) // Found fork point, non-existent branch if removeLeft { - if cld.path.Cmp(keyPos) < 0 { + if edgePath.Cmp(keyPos) < 0 { // Edge node path is in range, unset entire branch parent.(*binaryNode).children[keyBit] = nil } } else { - if cld.path.Cmp(keyPos) > 0 { + if edgePath.Cmp(keyPos) > 0 { parent.(*binaryNode).children[keyBit] = nil } } diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go index 65450e4def..2853d2935f 100644 --- a/core/trie2/proof_test.go +++ b/core/trie2/proof_test.go @@ -11,8 +11,6 @@ import ( ) func TestProve(t *testing.T) { - t.Parallel() - n := 1000 tempTrie, records := nonRandomTrie(t, n) @@ -35,8 +33,6 @@ func TestProve(t *testing.T) { } func TestProveNonExistent(t *testing.T) { - t.Parallel() - n := 1000 tempTrie, _ := nonRandomTrie(t, n) @@ -57,7 +53,6 @@ func TestProveNonExistent(t *testing.T) { } func TestProveRandom(t *testing.T) { - t.Parallel() tempTrie, records := randomTrie(t, 1000) for _, record := range records { @@ -77,8 +72,6 @@ func TestProveRandom(t *testing.T) { } func TestProveCustom(t *testing.T) { - t.Parallel() - tests := []testTrie{ { name: "simple binary", @@ -167,8 +160,6 @@ func TestProveCustom(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - t.Parallel() - tr, _ := test.buildFn(t) for _, tc := range test.testKeys { @@ -192,8 +183,6 @@ func TestProveCustom(t *testing.T) { // TestRangeProof tests normal range proof with both edge proofs func TestRangeProof(t *testing.T) { - t.Parallel() - n := 500 tr, records := randomTrie(t, n) root := tr.Hash() @@ -242,8 +231,6 @@ func TestRangeProofNonRandom(t *testing.T) { // TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs func TestRangeProofWithNonExistentProof(t *testing.T) { - t.Parallel() - n := 500 tr, records := randomTrie(t, n) root := tr.Hash() @@ -276,8 +263,6 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { // TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. // One scenario is when there is a gap between the first element and the left edge proof. func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { - t.Parallel() - n := 500 tr, records := randomTrie(t, n) root := tr.Hash() @@ -374,8 +359,6 @@ func TestOneElementRangeProof(t *testing.T) { // TestAllElementsRangeProof tests the range proof with all elements and nil proof. func TestAllElementsRangeProof(t *testing.T) { - t.Parallel() - n := 1000 tr, records := randomTrie(t, n) root := tr.Hash() @@ -401,8 +384,6 @@ func TestAllElementsRangeProof(t *testing.T) { // TestSingleSideRangeProof tests the range proof starting with zero. 
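// Each iteration proves the range from felt.Zero (almost certainly a non-existent key in a
// random trie) up to records[i].key, so only the left edge proof is exercised.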
func TestSingleSideRangeProof(t *testing.T) { - t.Parallel() - tr, records := randomTrie(t, 1000) root := tr.Hash() @@ -424,9 +405,6 @@ func TestSingleSideRangeProof(t *testing.T) { } func TestGappedRangeProof(t *testing.T) { - t.Parallel() - // t.Skip("gapped keys will sometimes succeed, the current proof format is not able to handle this") - tr, records := nonRandomTrie(t, 5) root := tr.Hash() @@ -451,8 +429,6 @@ func TestGappedRangeProof(t *testing.T) { } func TestEmptyRangeProof(t *testing.T) { - t.Parallel() - tr, records := randomTrie(t, 1000) root := tr.Hash() @@ -480,9 +456,7 @@ func TestEmptyRangeProof(t *testing.T) { } func TestHasRightElement(t *testing.T) { - t.Parallel() - - tr, records := randomTrie(t, 500) + tr, records := randomTrie(t, 10000) root := tr.Hash() cases := []struct { @@ -532,12 +506,10 @@ func TestHasRightElement(t *testing.T) { // TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. func TestBadRangeProof(t *testing.T) { - t.Parallel() - - tr, records := randomTrie(t, 1000) + tr, records := randomTrie(t, 5000) root := tr.Hash() - for i := 0; i < 100; i++ { + for i := 0; i < 500; i++ { start := rand.Intn(len(records)) end := rand.Intn(len(records)-start) + start + 1 @@ -553,7 +525,7 @@ func TestBadRangeProof(t *testing.T) { } first := keys[0] - testCase := rand.Intn(5) + testCase := rand.Intn(6) index := rand.Intn(end - start) switch testCase { @@ -572,13 +544,12 @@ func TestBadRangeProof(t *testing.T) { keys[index] = &felt.Zero case 4: // set random value to empty values[index] = &felt.Zero - // TODO(weiihann): gapped proof will fail sometimes - // case 5: // gapped - // if end-start < 100 || index == 0 || index == end-start-1 { - // continue - // } - // keys = append(keys[:index], keys[index+1:]...) - // values = append(values[:index], values[index+1:]...) + case 5: // gapped + if end-start < 100 || index == 0 || index == end-start-1 { + continue + } + keys = append(keys[:index], keys[index+1:]...) + values = append(values[:index], values[index+1:]...) } _, err = VerifyRangeProof(&root, first, keys, values, proof) if err == nil { diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index 27f2f9ba7e..2ebe72d77d 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -326,6 +326,11 @@ func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { return b.Append(x, new(BitArray).SetBit(bit)) } +// Sets the bit array to the concatenation of x and n zeros. +func (b *BitArray) AppendZeros(x *BitArray, n uint8) *BitArray { + return b.Append(x, new(BitArray).Zeros(n)) +} + // Sets the bit array to a subset of x from startPos (inclusive) to endPos (exclusive), // where position 0 is the MSB. If startPos >= endPos or if startPos >= x.len, // returns an empty BitArray. @@ -796,6 +801,15 @@ func (b *BitArray) Ones(length uint8) *BitArray { return b } +func (b *BitArray) Zeros(length uint8) *BitArray { + b.len = length + b.words[0] = 0 + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + return b +} + // Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. // The bit position is counted from the least significant bit, starting at 0. 
// For example: From 6e33a7a3a84cd8c5df2459909f462cc740328c5f Mon Sep 17 00:00:00 2001 From: weiihann Date: Tue, 21 Jan 2025 21:42:15 +0800 Subject: [PATCH 31/44] hasRightElement fix edgeNode case --- core/trie2/proof.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 432b410363..3ec88c07ca 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -550,7 +550,13 @@ func hasRightElement(node node, key Path) bool { // There's a divergence in the path, check if the node path is greater than the key // If so, that means that this node comes after the search key, which indicates that // there are elements with larger values - return n.path.Cmp(&key) > 0 + var edgePath *Path + if key.Len() > n.path.Len() { + edgePath = new(Path).AppendZeros(n.path, key.Len()-n.path.Len()) + } else { + edgePath = n.path + } + return edgePath.Cmp(&key) > 0 } node = n.child key.LSBs(&key, n.path.Len()) From aa57d85f37db1fdd37e11186b1aea4c23905077f Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 22 Jan 2025 10:32:49 +0800 Subject: [PATCH 32/44] remove test --- core/trie2/proof_test.go | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go index 2853d2935f..ed462e7992 100644 --- a/core/trie2/proof_test.go +++ b/core/trie2/proof_test.go @@ -207,28 +207,6 @@ func TestRangeProof(t *testing.T) { } } -func TestRangeProofNonRandom(t *testing.T) { - tr, records := nonRandomTrie(t, 6) - root := tr.Hash() - - t.Log("Original trie") - t.Log(tr.root.String()) - - proof := NewProofNodeSet() - err := tr.GetRangeProof(records[0].key, records[3].key, proof) - require.NoError(t, err) - - for i := 0; i < 4; i++ { - t.Logf("key %s: value %s", records[i].key.String(), records[i].value.String()) - } - - keys := []*felt.Felt{records[0].key, records[1].key, records[2].key, records[3].key} - values := []*felt.Felt{records[0].value, records[1].value, records[2].value, records[3].value} - - _, err = VerifyRangeProof(&root, records[0].key, keys, values, proof) - require.NoError(t, err) -} - // TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs func TestRangeProofWithNonExistentProof(t *testing.T) { n := 500 From 20b713a4e96aec42826ca451642784fc117df84d Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 22 Jan 2025 18:19:27 +0800 Subject: [PATCH 33/44] fix node encoding and add tests --- core/trie2/node_enc.go | 33 +++++-- core/trie2/node_enc_test.go | 178 ++++++++++++++++++++++++++++++++++++ 2 files changed, 202 insertions(+), 9 deletions(-) create mode 100644 core/trie2/node_enc_test.go diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go index 8b0074c96f..089011f3af 100644 --- a/core/trie2/node_enc.go +++ b/core/trie2/node_enc.go @@ -11,7 +11,7 @@ import ( const ( binaryNodeSize = 2 * hashOrValueNodeSize // LeftHash + RightHash - edgeNodeSize = trieutils.MaxBitArraySize + hashOrValueNodeSize // Path + Child Hash (max size, could be less) + edgeNodeMaxSize = trieutils.MaxBitArraySize + hashOrValueNodeSize // Path + Child Hash hashOrValueNodeSize = felt.Bytes ) @@ -72,7 +72,10 @@ func nodeToBytes(n node) []byte { if err := n.write(buf); err != nil { panic(err) } - return buf.Bytes() + + res := make([]byte, buf.Len()) + copy(res, buf.Bytes()) + return res } func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) { @@ -82,14 +85,20 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e isLeaf bool 
) + if pathLen > maxPathLen { + return nil, fmt.Errorf("node path length (%d) is greater than max path length (%d)", pathLen, maxPathLen) + } + isLeaf = pathLen == maxPathLen switch len(blob) { case hashOrValueNodeSize: + var f felt.Felt + f.SetBytes(blob) if isLeaf { - n = &valueNode{Felt: hash} + n = &valueNode{Felt: f} } else { - n = &hashNode{Felt: hash} + n = &hashNode{Felt: f} } case binaryNodeSize: binary := &binaryNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash @@ -104,16 +113,22 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e n = binary default: - // Edge node size is capped, if the blob is larger than the max size, it's invalid - if len(blob) > edgeNodeSize { + // Ensure the blob length is within the valid range for an edge node + if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { return nil, fmt.Errorf("invalid node size: %d", len(blob)) } - edge := &edgeNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash - edge.child, err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen, maxPathLen) + + edge := &edgeNode{ + path: &trieutils.BitArray{}, + flags: nodeFlag{hash: &hashNode{Felt: hash}}, // cache the hash + } + edge.child, err = decodeNode(blob[:hashOrValueNodeSize], felt.Felt{}, pathLen, maxPathLen) if err != nil { return nil, err } - edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]) + if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil { + return nil, err + } // We do another path length check to see if the node is a leaf if pathLen+edge.path.Len() == maxPathLen { diff --git a/core/trie2/node_enc_test.go b/core/trie2/node_enc_test.go new file mode 100644 index 0000000000..022e1730f9 --- /dev/null +++ b/core/trie2/node_enc_test.go @@ -0,0 +1,178 @@ +package trie2 + +import ( + "bytes" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNodeEncodingDecoding(t *testing.T) { + newFelt := func(v uint64) felt.Felt { + var f felt.Felt + f.SetUint64(v) + return f + } + + // Helper to create a path with specific bits + newPath := func(bits ...uint8) *trieutils.BitArray { + path := &trieutils.BitArray{} + for _, bit := range bits { + path.AppendBit(path, bit) + } + return path + } + + tests := []struct { + name string + node node + pathLen uint8 + maxPath uint8 + wantErr bool + errMsg string + }{ + { + name: "edge node with value child", + node: &edgeNode{ + path: newPath(1, 0, 1), + child: &valueNode{Felt: newFelt(123)}, + flags: nodeFlag{}, + }, + pathLen: 8, + maxPath: 8, + }, + { + name: "edge node with hash child", + node: &edgeNode{ + path: newPath(0, 1, 0), + child: &hashNode{Felt: newFelt(456)}, + flags: nodeFlag{}, + }, + pathLen: 3, + maxPath: 8, + }, + { + name: "binary node with two hash children", + node: &binaryNode{ + children: [2]node{ + &hashNode{Felt: newFelt(111)}, + &hashNode{Felt: newFelt(222)}, + }, + flags: nodeFlag{}, + }, + pathLen: 0, + maxPath: 8, + }, + { + name: "binary node with two leaf children", + node: &binaryNode{ + children: [2]node{ + &valueNode{Felt: newFelt(555)}, + &valueNode{Felt: newFelt(666)}, + }, + flags: nodeFlag{}, + }, + pathLen: 7, + maxPath: 8, + }, + { + name: "value node at max path length", + node: &valueNode{Felt: newFelt(999)}, + pathLen: 8, + maxPath: 8, + }, + { + name: "hash node below max path length", + node: 
&hashNode{Felt: newFelt(1000)}, + pathLen: 7, + maxPath: 8, + }, + { + name: "edge node with empty path", + node: &edgeNode{ + path: newPath(), + child: &hashNode{Felt: newFelt(1111)}, + flags: nodeFlag{}, + }, + pathLen: 0, + maxPath: 8, + }, + { + name: "edge node with max length path", + node: &edgeNode{ + path: newPath(1, 1, 1, 1, 1, 1, 1, 1), + child: &valueNode{Felt: newFelt(2222)}, + flags: nodeFlag{}, + }, + pathLen: 0, + maxPath: 8, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Encode the node + encoded := nodeToBytes(tt.node) + + // Try to decode + hash := tt.node.hash(crypto.Pedersen) + decoded, err := decodeNode(encoded, *hash, tt.pathLen, tt.maxPath) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + return + } + + require.NoError(t, err) + require.NotNil(t, decoded) + + // Re-encode the decoded node and compare with original encoding + reEncoded := nodeToBytes(decoded) + assert.True(t, bytes.Equal(encoded, reEncoded), "re-encoded node doesn't match original encoding") + + // Test specific node type assertions and properties + switch n := decoded.(type) { + case *edgeNode: + original := tt.node.(*edgeNode) + assert.Equal(t, original.path.String(), n.path.String(), "edge node paths don't match") + case *binaryNode: + // Verify both children are present + assert.NotNil(t, n.children[0], "left child is nil") + assert.NotNil(t, n.children[1], "right child is nil") + case *valueNode: + original := tt.node.(*valueNode) + assert.True(t, original.Felt.Equal(&n.Felt), "value node felts don't match") + case *hashNode: + original := tt.node.(*hashNode) + assert.True(t, original.Felt.Equal(&n.Felt), "hash node felts don't match") + } + }) + } +} + +func TestNodeEncodingDecodingBoundary(t *testing.T) { + // Test empty/nil nodes + t.Run("nil node encoding", func(t *testing.T) { + assert.Panics(t, func() { nodeToBytes(nil) }) + }) + + // Test with invalid path lengths + t.Run("invalid path lengths", func(t *testing.T) { + hash := felt.Zero + blob := make([]byte, hashOrValueNodeSize) + _, err := decodeNode(blob, hash, 255, 8) // pathLen > maxPath + require.Error(t, err) + }) + + // Test with empty buffer + t.Run("empty buffer", func(t *testing.T) { + hash := felt.Zero + _, err := decodeNode([]byte{}, hash, 0, 8) + require.Error(t, err) + }) +} From 252375c8ffbac4a373b1f06aec2ce3ca3d00fa7f Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 22 Jan 2025 22:16:14 +0800 Subject: [PATCH 34/44] everything pass!!!!! --- core/trie2/collector.go | 11 +- core/trie2/node_enc.go | 38 ++++- core/trie2/proof.go | 2 +- core/trie2/trie.go | 63 +++++++-- core/trie2/trie_test.go | 270 ++++++++++++++++++++++++++++++++++-- core/trie2/trienode/node.go | 4 + 6 files changed, 354 insertions(+), 34 deletions(-) diff --git a/core/trie2/collector.go b/core/trie2/collector.go index 8b47441b98..fa4976b8b3 100644 --- a/core/trie2/collector.go +++ b/core/trie2/collector.go @@ -39,9 +39,7 @@ func (c *collector) collect(path *Path, n node, parallel bool) node { // If the child is a binary node, recurse into it. // Otherwise, it can only be a hashNode or valueNode. // Combination of edge (parent) + edge (child) is not possible. 
- if _, ok := cn.child.(*binaryNode); ok { - collapsed.child = c.collect(new(Path).Append(path, cn.path), cn.child, parallel) - } + collapsed.child = c.collect(new(Path).Append(path, cn.path), cn.child, parallel) return c.store(path, collapsed) case *binaryNode: collapsed := cn.copy() @@ -70,7 +68,7 @@ func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2 } // Create child path - childPath := new(Path).Append(path, new(Path).SetBit(uint8(i))) + childPath := new(Path).AppendBit(path, uint8(i)) if !parallel { children[i] = c.collect(childPath, child, parallel) @@ -113,11 +111,12 @@ func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2 func (c *collector) store(path *Path, n node) node { hash, _ := n.cache() + blob := nodeToBytes(n) if hash == nil { // this is a value node - c.nodes.Add(*path, trienode.NewNode(felt.Felt{}, nodeToBytes(n))) + c.nodes.Add(*path, trienode.NewNode(felt.Felt{}, blob)) return n } - c.nodes.Add(*path, trienode.NewNode(hash.Felt, nodeToBytes(n))) + c.nodes.Add(*path, trienode.NewNode(hash.Felt, blob)) return hash } diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go index 089011f3af..3d454391e7 100644 --- a/core/trie2/node_enc.go +++ b/core/trie2/node_enc.go @@ -2,6 +2,7 @@ package trie2 import ( "bytes" + "errors" "fmt" "sync" @@ -21,7 +22,20 @@ var bufferPool = sync.Pool{ }, } +// The initial idea was to differentiate between binary and edge nodes by the size of the buffer. +// However, there may be a case where the size of the buffer is the same for both binary and edge nodes. +// In this case, we need to prepend the buffer with a byte to indicate the type of the node. +// Value and hash nodes do not need this because their size is fixed. +const ( + binaryNodeType byte = iota + 1 + edgeNodeType +) + func (n *binaryNode) write(buf *bytes.Buffer) error { + if err := buf.WriteByte(binaryNodeType); err != nil { + return err + } + if err := n.children[0].write(buf); err != nil { return err } @@ -34,6 +48,10 @@ func (n *binaryNode) write(buf *bytes.Buffer) error { } func (n *edgeNode) write(buf *bytes.Buffer) error { + if err := buf.WriteByte(edgeNodeType); err != nil { + return err + } + if err := n.child.write(buf); err != nil { return err } @@ -79,6 +97,10 @@ func nodeToBytes(n node) []byte { } func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) { + if len(blob) == 0 { + return nil, errors.New("cannot decode empty blob") + } + var ( n node err error @@ -91,8 +113,7 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e isLeaf = pathLen == maxPathLen - switch len(blob) { - case hashOrValueNodeSize: + if len(blob) == hashOrValueNodeSize { var f felt.Felt f.SetBytes(blob) if isLeaf { @@ -100,7 +121,14 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e } else { n = &hashNode{Felt: f} } - case binaryNodeSize: + return n, nil + } + + nodeType := blob[0] + blob = blob[1:] + + switch nodeType { + case binaryNodeType: binary := &binaryNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen) if err != nil { @@ -112,7 +140,7 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e } n = binary - default: + case edgeNodeType: // Ensure the blob length is within the valid range for an edge node if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { return nil, fmt.Errorf("invalid 
node size: %d", len(blob)) @@ -136,6 +164,8 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e } n = edge + default: + panic(fmt.Sprintf("unknown decode node type: %d", nodeType)) } return n, nil diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 3ec88c07ca..322a6ff2c7 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -57,7 +57,7 @@ func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { } rn = resolved default: - panic(fmt.Sprintf("unknown node type: %T", n)) + panic(fmt.Sprintf("key: %s, unknown node type: %T", key.String(), n)) } } diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 4e233107b4..78e2e036a9 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -97,19 +97,16 @@ func (t *Trie) Update(key, value *felt.Felt) error { // Retrieves the value associated with the given key. // Returns felt.Zero if the key doesn't exist. // May update the trie's internal structure if nodes need to be resolved. +// TODO(weiihann): +// The State should keep track of the modified key and values, so that we can avoid traversing the trie +// No action needed for the Trie, remove this once State provides the functionality func (t *Trie) Get(key *felt.Felt) (felt.Felt, error) { if t.committed { return felt.Zero, ErrCommitted } k := t.FeltToPath(key) - // We first check if the value node exists in the trie database directly - v, err := t.resolveNode(&hashNode{}, k) - if _, ok := v.(*valueNode); ok && err == nil { - return v.(*valueNode).Felt, nil - } - // Otherwise, we need to traverse the trie to find the value node var ret felt.Felt val, root, didResolve, err := t.get(t.root, new(Path), &k) // In Starknet, a non-existent key is mapped to felt.Zero @@ -156,6 +153,25 @@ func (t *Trie) Commit() (felt.Felt, error) { t.committed = true }() + // Trie is empty and can be classified into two types of situations: + // (a) The trie was empty and no update happens => return empty root + // (b) The trie was non-empty and all nodes are dropped => commit and return empty root + if t.root == nil { + paths := t.nodeTracer.deletedNodes() + if len(paths) == 0 { // case (a) + return felt.Zero, nil + } + // case (b) + nodes := trienode.NewNodeSet(t.owner) + for _, path := range paths { + nodes.Add(path, trienode.NewDeleted()) + } + err := nodes.ForEach(true, func(key trieutils.BitArray, node *trienode.Node) error { + return t.db.Delete(t.owner, key) + }) + return felt.Zero, err + } + rootHash := t.Hash() if hashedNode, dirty := t.root.cache(); !dirty { t.root = hashedNode @@ -174,6 +190,9 @@ func (t *Trie) Commit() (felt.Felt, error) { if node.IsDeleted() { return t.db.Delete(t.owner, key) } + // decodeNode, _ := decodeNode(node.Blob(), node.Hash(), key.Len(), t.height) + // fmt.Printf("key: %v value: blob %v hash %v\n", key.String(), node.Blob(), node.Hash()) + // fmt.Printf("decodeNode: %v\n", decodeNode.String()) return t.db.Put(t.owner, key, node.Blob()) }) if err != nil { @@ -201,7 +220,7 @@ func (t *Trie) get(n node, prefix, key *Path) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: if !n.pathMatches(key) { - return nil, nil, false, nil + return nil, n, false, nil } val, child, didResolve, err := t.get(n.child, new(Path).Append(prefix, n.path), key.LSBs(key, n.path.Len())) if err == nil && didResolve { @@ -218,7 +237,7 @@ func (t *Trie) get(n node, prefix, key *Path) (*felt.Felt, node, bool, error) { } return val, n, didResolve, err case *hashNode: - child, err := t.resolveNode(n, *key) + child, err := 
t.resolveNode(n, *prefix) if err != nil { return nil, nil, false, err } @@ -258,7 +277,7 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) if key.Len() == 0 { if v, ok := n.(*valueNode); ok { vFelt := value.(*valueNode).Felt - return v.Equal(&vFelt), value, nil + return !v.Equal(&vFelt), value, nil } return true, value, nil } @@ -330,7 +349,7 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) // Otherwise, return a new edge node with the Path being the key and the value as the child return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil case *hashNode: - child, err := t.resolveNode(n, *key) + child, err := t.resolveNode(n, *prefix) if err != nil { return false, n, err } @@ -361,7 +380,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { // Otherwise, key is longer than current node path, so we need to delete the child. // Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie. keyPrefix := new(Path).MSBs(key, n.path.Len()) - dirty, child, err := t.delete(n.child, new(Path).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len())) + dirty, child, err := t.delete(n.child, new(Path).Append(prefix, keyPrefix), new(Path).LSBs(key, n.path.Len())) if !dirty || err != nil { return false, n, err } @@ -389,9 +408,22 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { } // Otherwise, we need to combine this binary node with the other child + // If it's a hash node, we need to resolve it first. + // If the other child is an edge node, we prepend the bit prefix to the other child path other := bit ^ 1 bitPrefix := new(Path).SetBit(other) - if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child Path + + if hn, ok := n.children[other].(*hashNode); ok { + var cPath Path + cPath.Append(prefix, bitPrefix) + cNode, err := t.resolveNode(hn, cPath) + if err != nil { + return false, nil, err + } + n.children[other] = cNode + } + + if cn, ok := n.children[other].(*edgeNode); ok { t.nodeTracer.onDelete(new(Path).Append(prefix, bitPrefix)) return true, &edgeNode{ path: new(Path).Append(bitPrefix, cn.path), @@ -408,7 +440,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { case nil: return false, nil, nil case *hashNode: - child, err := t.resolveNode(n, *key) + child, err := t.resolveNode(n, *prefix) if err != nil { return false, nil, err } @@ -431,6 +463,8 @@ func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { bufferPool.Put(buf) }() + // fmt.Printf("resolveNode: path %v\n", path.String()) + _, err := t.db.Get(buf, t.owner, path) if err != nil { return nil, err @@ -441,6 +475,9 @@ func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { } func (t *Trie) hashRoot() (node, node) { + if t.root == nil { + return &hashNode{Felt: felt.Zero}, nil + } h := newHasher(t.hashFn, t.pendingHashes > 100) // TODO(weiihann): 100 is arbitrary hashed, cached := h.hash(t.root) t.pendingHashes = 0 diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 8ef0fd41f5..d728ed7896 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -1,12 +1,19 @@ package trie2 import ( + "fmt" + "io" "math/rand" + "reflect" "sort" "testing" + "testing/quick" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/utils" + 
"github.com/davecgh/go-spew/spew" "github.com/stretchr/testify/require" ) @@ -61,6 +68,12 @@ func TestDelete(t *testing.T) { // The expected hashes are taken from Pathfinder's tests func TestHash(t *testing.T) { + t.Run("empty", func(t *testing.T) { + tr, _ := NewEmptyPedersen() + hash := tr.Hash() + require.Equal(t, felt.Zero, hash) + }) + t.Run("one leaf", func(t *testing.T) { tr, _ := NewEmptyPedersen() err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)) @@ -169,25 +182,262 @@ func TestHash(t *testing.T) { }) } +func TestMissingRoot(t *testing.T) { + var root felt.Felt + root.SetUint64(1) + + tr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) + require.Nil(t, tr) + require.Error(t, err) +} + func TestCommit(t *testing.T) { - t.Run("sequential", func(t *testing.T) { - tr, _ := nonRandomTrie(t, 10000) + verifyCommit := func(t *testing.T, records []*keyValue) { + t.Helper() + db := db.NewMemTransaction() + tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, db) + require.NoError(t, err) + + for _, record := range records { + err := tr.Update(record.key, record.value) + require.NoError(t, err) + } + + root, err := tr.Commit() + require.NoError(t, err) - _, err := tr.Commit() + tr2, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, db) require.NoError(t, err) + + for _, record := range records { + got, err := tr2.Get(record.key) + require.NoError(t, err) + require.True(t, got.Equal(record.value), "expected %v, got %v", record.value, got) + } + } + + t.Run("sequential", func(t *testing.T) { + _, records := nonRandomTrie(t, 10000) + verifyCommit(t, records) }) t.Run("random", func(t *testing.T) { - tr, _ := randomTrie(t, 10000) - - _, err := tr.Commit() - require.NoError(t, err) + _, records := randomTrie(t, 10000) + verifyCommit(t, records) }) } -func TestTrieOpsRandom(t *testing.T) { - t.Skip() - panic("implement me") +func TestRandom(t *testing.T) { + if err := quick.Check(runRandTestBool, nil); err != nil { + if cerr, ok := err.(*quick.CheckError); ok { + t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In)) + } + t.Fatal(err) + } +} + +func TestSpecificRandomFailure(t *testing.T) { + // Create test steps that match the failing sequence + key1 := utils.HexToFelt(t, "0x5c0d77d04056ae1e0c49bce8a3223b0373ccf94c1b04935d") + key2 := utils.HexToFelt(t, "0x19bc6d500358f3046b0743c903") + key3 := utils.HexToFelt(t, "0xf7d2180a5a138b325ea522e2b65b58a5dffb859b1fbbdab0e5efb125e8a840") + + steps := []randTestStep{ + {op: opProve, key: key1}, + {op: opCommit}, + {op: opDelete, key: key2}, + {op: opHash}, + {op: opCommit}, + {op: opCommit}, + {op: opGet, key: key2}, + {op: opUpdate, key: key2, value: new(felt.Felt).SetUint64(7)}, + {op: opHash}, + {op: opDelete, key: key1}, + {op: opUpdate, key: key1, value: new(felt.Felt).SetUint64(10)}, + {op: opHash}, + {op: opGet, key: key2}, + {op: opUpdate, key: key1, value: new(felt.Felt).SetUint64(13)}, + {op: opCommit}, + {op: opHash}, + {op: opProve, key: key1}, + {op: opProve, key: key3}, + {op: opUpdate, key: key3, value: new(felt.Felt).SetUint64(18)}, + {op: opCommit}, + {op: opGet, key: key2}, + {op: opUpdate, key: key2, value: new(felt.Felt).SetUint64(21)}, + {op: opHash}, + {op: opHash}, + {op: opGet, key: key2}, + {op: opProve, key: key1}, + {op: opUpdate, key: key2, value: new(felt.Felt).SetUint64(26)}, + {op: opHash}, + {op: opGet, key: key2}, + {op: opCommit}, + {op: opCommit}, + {op: opProve, key: key3}, // This is 
where the original failure occurred + } + + // Add debug logging + t.Log("Starting test sequence") + err := runRandTest(steps) + if err != nil { + t.Logf("Test failed at step: %v", err) + // Print the state of the trie at failure + for i, step := range steps { + if step.err != nil { + t.Logf("Failed at step %d: %v", i, step) + break + } + } + } + require.NoError(t, err, "specific random test sequence should not fail") +} + +const ( + opUpdate = iota + opDelete + opGet + opHash + opCommit + opProve + opMax // max number of operations, not an actual operation +) + +type randTestStep struct { + op int + key *felt.Felt // for opUpdate, opDelete, opGet + value *felt.Felt // for opUpdate + err error +} + +type randTest []randTestStep + +func (randTest) Generate(r *rand.Rand, size int) reflect.Value { + finishedFn := func() bool { + size-- + return size == 0 + } + return reflect.ValueOf(generateSteps(finishedFn, r)) +} + +func generateSteps(finished func() bool, r io.Reader) randTest { + var allKeys []*felt.Felt + random := []byte{0} + + genKey := func() *felt.Felt { + r.Read(random) + // Create a new key with 10% probability or when < 2 keys exist + if len(allKeys) < 2 || random[0]%100 > 90 { + size := random[0] % 32 // ensure key size is between 1 and 32 bytes + key := make([]byte, size) + r.Read(key) + allKeys = append(allKeys, new(felt.Felt).SetBytes(key)) + } + // 90% probability to return an existing key + idx := int(random[0]) % len(allKeys) + return allKeys[idx] + } + + var steps randTest + for !finished() { + r.Read(random) + step := randTestStep{op: int(random[0]) % opMax} + switch step.op { + case opUpdate: + step.key = genKey() + step.value = new(felt.Felt).SetUint64(uint64(len(steps))) + case opGet, opDelete, opProve: + step.key = genKey() + } + steps = append(steps, step) + } + return steps +} + +func runRandTestBool(rt randTest) bool { + return runRandTest(rt) == nil +} + +func runRandTest(rt randTest) error { + txn := db.NewMemTransaction() + tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, txn) + if err != nil { + return err + } + + values := make(map[felt.Felt]felt.Felt) // keeps track of the content of the trie + + for i, step := range rt { + // fmt.Printf("Step %d: %d key=%s value=%s\n", i, step.op, step.key, step.value) + switch step.op { + case opUpdate: + err := tr.Update(step.key, step.value) + // fmt.Println("--------------------------------") + // fmt.Println(tr.root.String()) + if err != nil { + rt[i].err = fmt.Errorf("update failed: key %s, %w", step.key.String(), err) + } + values[*step.key] = *step.value + case opDelete: + err := tr.Delete(step.key) + // fmt.Println("--------------------------------") + // if tr.root != nil { + // fmt.Println("trie", tr.root.String()) + // } else { + // fmt.Println("nil root") + // } + if err != nil { + rt[i].err = fmt.Errorf("delete failed: key %s, %w", step.key.String(), err) + } + delete(values, *step.key) + case opGet: + got, err := tr.Get(step.key) + if err != nil { + rt[i].err = fmt.Errorf("get failed: key %s, %w", step.key.String(), err) + } + want := values[*step.key] + if !got.Equal(&want) { + rt[i].err = fmt.Errorf("mismatch in get: key %s, expected %v, got %v", step.key.String(), want.String(), got.String()) + } + case opProve: + hash := tr.Hash() + if hash.Equal(&felt.Zero) { + continue + } + proof := NewProofNodeSet() + err := tr.Prove(step.key, proof) + if err != nil { + rt[i].err = fmt.Errorf("prove failed for key %s: %w", step.key.String(), err) + } + _, err = VerifyProof(&hash, step.key, 
proof, crypto.Pedersen) + if err != nil { + rt[i].err = fmt.Errorf("verify proof failed for key %s: %w", step.key.String(), err) + } + case opHash: + tr.Hash() + case opCommit: + root, err := tr.Commit() + if err != nil { + rt[i].err = fmt.Errorf("commit failed: %w", err) + } + newtr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, txn) + // fmt.Println("--------------------------------") + // if newtr.root != nil { + // fmt.Println(newtr.root.String()) + // } else { + // fmt.Println("nil root") + // } + if err != nil { + rt[i].err = fmt.Errorf("new trie failed: %w", err) + } + tr = newtr + } + + if rt[i].err != nil { + return rt[i].err + } + } + return nil } type keyValue struct { diff --git a/core/trie2/trienode/node.go b/core/trie2/trienode/node.go index febdbf53d1..55f58ab34f 100644 --- a/core/trie2/trienode/node.go +++ b/core/trie2/trienode/node.go @@ -17,6 +17,10 @@ func (r *Node) Blob() []byte { return r.blob } +func (r *Node) Hash() felt.Felt { + return r.hash +} + func NewNode(hash felt.Felt, blob []byte) *Node { return &Node{hash: hash, blob: blob} } From e13fd3a180e7428d48c051cd72297181a48c2c9f Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 Jan 2025 09:55:07 +0800 Subject: [PATCH 35/44] linter --- core/trie2/collector.go | 2 -- core/trie2/proof.go | 4 +++- core/trie2/tracer.go | 5 ----- core/trie2/trie.go | 11 ++++------- core/trie2/trie_test.go | 22 ++++------------------ core/trie2/trieutils/bitarray.go | 2 +- 6 files changed, 12 insertions(+), 34 deletions(-) diff --git a/core/trie2/collector.go b/core/trie2/collector.go index fa4976b8b3..710f55b5bb 100644 --- a/core/trie2/collector.go +++ b/core/trie2/collector.go @@ -3,9 +3,7 @@ package trie2 import ( "fmt" "sync" -) -import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie2/trienode" ) diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 322a6ff2c7..2da26947f1 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -355,6 +355,8 @@ func proofToPath(rootHash *felt.Felt, root node, keyBits Path, proof *ProofNodeS // // Note we have the assumption here the given boundary keys are different // and right is larger than left. +// +//nolint:gocyclo func unsetInternal(n node, left, right Path) (bool, error) { // Step down to the fork point. There are two scenarios that can happen: // - the fork point is an edgeNode: either the key of left proof or @@ -479,7 +481,7 @@ findFork: } // unset removes all internal node references either the left most or right most. 
-func unset(parent node, child node, key *Path, pos uint8, removeLeft bool) error { +func unset(parent, child node, key *Path, pos uint8, removeLeft bool) error { switch cld := child.(type) { case *binaryNode: keyBit := key.Bit(pos) diff --git a/core/trie2/tracer.go b/core/trie2/tracer.go index 0d0659130a..3ee9ff7e9b 100644 --- a/core/trie2/tracer.go +++ b/core/trie2/tracer.go @@ -32,11 +32,6 @@ func (t *nodeTracer) onDelete(key *Path) { t.deletes[k] = struct{}{} } -func (t *nodeTracer) reset() { - t.inserts = make(map[Path]struct{}) - t.deletes = make(map[Path]struct{}) -} - func (t *nodeTracer) copy() *nodeTracer { return &nodeTracer{ inserts: maps.Clone(t.inserts), diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 78e2e036a9..dba6b96fff 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -183,16 +183,13 @@ func (t *Trie) Commit() (felt.Felt, error) { nodes.Add(Path, trienode.NewDeleted()) } - t.root = newCollector(nodes).Collect(t.root, t.pendingUpdates > 100) // TODO(weiihann): 100 is arbitrary + t.root = newCollector(nodes).Collect(t.root, t.pendingUpdates > 100) //nolint:mnd // TODO(weiihann): 100 is arbitrary t.pendingUpdates = 0 err := nodes.ForEach(true, func(key trieutils.BitArray, node *trienode.Node) error { if node.IsDeleted() { return t.db.Delete(t.owner, key) } - // decodeNode, _ := decodeNode(node.Blob(), node.Hash(), key.Len(), t.height) - // fmt.Printf("key: %v value: blob %v hash %v\n", key.String(), node.Blob(), node.Hash()) - // fmt.Printf("decodeNode: %v\n", decodeNode.String()) return t.db.Put(t.owner, key, node.Blob()) }) if err != nil { @@ -272,6 +269,7 @@ func (t *Trie) update(key, value *felt.Felt) error { return nil } +//nolint:gocyclo,funlen func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) { // We reach the end of the key if key.Len() == 0 { @@ -363,6 +361,7 @@ func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) } } +//nolint:gocyclo,funlen func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { switch n := n.(type) { case *edgeNode: @@ -463,8 +462,6 @@ func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { bufferPool.Put(buf) }() - // fmt.Printf("resolveNode: path %v\n", path.String()) - _, err := t.db.Get(buf, t.owner, path) if err != nil { return nil, err @@ -478,7 +475,7 @@ func (t *Trie) hashRoot() (node, node) { if t.root == nil { return &hashNode{Felt: felt.Zero}, nil } - h := newHasher(t.hashFn, t.pendingHashes > 100) // TODO(weiihann): 100 is arbitrary + h := newHasher(t.hashFn, t.pendingHashes > 100) //nolint:mnd //TODO(weiihann): 100 is arbitrary hashed, cached := h.hash(t.root) t.pendingHashes = 0 return hashed, cached diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index d728ed7896..7effe704a2 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -325,12 +325,12 @@ func generateSteps(finished func() bool, r io.Reader) randTest { random := []byte{0} genKey := func() *felt.Felt { - r.Read(random) + _, _ = r.Read(random) // Create a new key with 10% probability or when < 2 keys exist if len(allKeys) < 2 || random[0]%100 > 90 { size := random[0] % 32 // ensure key size is between 1 and 32 bytes key := make([]byte, size) - r.Read(key) + _, _ = r.Read(key) allKeys = append(allKeys, new(felt.Felt).SetBytes(key)) } // 90% probability to return an existing key @@ -340,7 +340,7 @@ func generateSteps(finished func() bool, r io.Reader) randTest { var steps randTest for !finished() { - r.Read(random) + _, _ = r.Read(random) 
step := randTestStep{op: int(random[0]) % opMax} switch step.op { case opUpdate: @@ -358,6 +358,7 @@ func runRandTestBool(rt randTest) bool { return runRandTest(rt) == nil } +//nolint:gocyclo func runRandTest(rt randTest) error { txn := db.NewMemTransaction() tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, txn) @@ -368,24 +369,15 @@ func runRandTest(rt randTest) error { values := make(map[felt.Felt]felt.Felt) // keeps track of the content of the trie for i, step := range rt { - // fmt.Printf("Step %d: %d key=%s value=%s\n", i, step.op, step.key, step.value) switch step.op { case opUpdate: err := tr.Update(step.key, step.value) - // fmt.Println("--------------------------------") - // fmt.Println(tr.root.String()) if err != nil { rt[i].err = fmt.Errorf("update failed: key %s, %w", step.key.String(), err) } values[*step.key] = *step.value case opDelete: err := tr.Delete(step.key) - // fmt.Println("--------------------------------") - // if tr.root != nil { - // fmt.Println("trie", tr.root.String()) - // } else { - // fmt.Println("nil root") - // } if err != nil { rt[i].err = fmt.Errorf("delete failed: key %s, %w", step.key.String(), err) } @@ -421,12 +413,6 @@ func runRandTest(rt randTest) error { rt[i].err = fmt.Errorf("commit failed: %w", err) } newtr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, txn) - // fmt.Println("--------------------------------") - // if newtr.root != nil { - // fmt.Println(newtr.root.String()) - // } else { - // fmt.Println("nil root") - // } if err != nil { rt[i].err = fmt.Errorf("new trie failed: %w", err) } diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index 2ebe72d77d..b83a59923e 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -18,7 +18,7 @@ const ( MaxBitArraySize = 33 // (1 + 4 * 8) bytes ) -var emptyBitArray = new(BitArray) +var emptyBitArray = new(BitArray) //nolint:unused //TODO(weiihann): remove this nolint when we replace legacy trie // Represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. From 4ed10a14b0c063cf31e5b52f9819025af0bd4019 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 Jan 2025 10:09:05 +0800 Subject: [PATCH 36/44] add comments --- core/trie2/hasher.go | 4 ++-- core/trie2/id.go | 2 ++ core/trie2/node_enc.go | 6 ++++++ core/trie2/proof.go | 2 ++ core/trie2/tracer.go | 7 +++++++ core/trie2/trie.go | 2 ++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go index 0fe11c6590..6917896b0b 100644 --- a/core/trie2/hasher.go +++ b/core/trie2/hasher.go @@ -7,7 +7,7 @@ import ( "github.com/NethermindEth/juno/core/crypto" ) -// hasher handles node hashing for the trie. It supports both sequential and parallel +// A tool for shashing nodes in the trie. It supports both sequential and parallel // hashing modes. type hasher struct { hashFn crypto.HashFn // The hash function to use @@ -21,7 +21,7 @@ func newHasher(hash crypto.HashFn, parallel bool) hasher { } } -// hash computes the hash of a node and returns both the hash node and a cached +// Computes the hash of a node and returns both the hash node and a cached // version of the original node. If the node already has a cached hash, returns // that instead of recomputing. 
func (h *hasher) hash(n node) (node, node) { diff --git a/core/trie2/id.go b/core/trie2/id.go index 4150ea6a37..49e3b09fe6 100644 --- a/core/trie2/id.go +++ b/core/trie2/id.go @@ -21,6 +21,7 @@ type ID struct { StorageRoot felt.Felt // The root hash of the storage trie of a contract. } +// Returns the corresponding DB bucket for the trie func (id *ID) Bucket() db.Bucket { switch id.TrieType { case ClassTrie: @@ -57,6 +58,7 @@ func ContractTrieID(root, owner, storageRoot felt.Felt) *ID { } } +// A general identifier, typically used for temporary trie func TrieID(root felt.Felt) *ID { return &ID{ TrieType: Empty, diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go index 3d454391e7..3cc0229ecf 100644 --- a/core/trie2/node_enc.go +++ b/core/trie2/node_enc.go @@ -31,6 +31,7 @@ const ( edgeNodeType ) +// Enc(binary) = binaryNodeType + HashNode(left) + HashNode(right) func (n *binaryNode) write(buf *bytes.Buffer) error { if err := buf.WriteByte(binaryNodeType); err != nil { return err @@ -47,6 +48,7 @@ func (n *binaryNode) write(buf *bytes.Buffer) error { return nil } +// Enc(edge) = edgeNodeType + HashNode(child) + Path func (n *edgeNode) write(buf *bytes.Buffer) error { if err := buf.WriteByte(edgeNodeType); err != nil { return err @@ -63,6 +65,7 @@ func (n *edgeNode) write(buf *bytes.Buffer) error { return nil } +// Enc(hash) = Felt func (n *hashNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err @@ -71,6 +74,7 @@ func (n *hashNode) write(buf *bytes.Buffer) error { return nil } +// Enc(value) = Felt func (n *valueNode) write(buf *bytes.Buffer) error { if _, err := buf.Write(n.Felt.Marshal()); err != nil { return err @@ -79,6 +83,7 @@ func (n *valueNode) write(buf *bytes.Buffer) error { return nil } +// Returns the encoded bytes of a node func nodeToBytes(n node) []byte { buf := bufferPool.Get().(*bytes.Buffer) buf.Reset() @@ -96,6 +101,7 @@ func nodeToBytes(n node) []byte { return res } +// Decodes the encoded bytes and returns the corresponding node func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) { if len(blob) == 0 { return nil, errors.New("cannot decode empty blob") diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 2da26947f1..1916cb1339 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -571,6 +571,8 @@ func hasRightElement(node node, key Path) bool { return false } +// Resolves the whole path of the given key and node. +// If skipResolved is true, it will only return the immediate child node of the current node func get(rn node, key *Path, skipResolved bool) node { for { switch n := rn.(type) { diff --git a/core/trie2/tracer.go b/core/trie2/tracer.go index 3ee9ff7e9b..3197965221 100644 --- a/core/trie2/tracer.go +++ b/core/trie2/tracer.go @@ -4,6 +4,7 @@ import ( "maps" ) +// Tracks the changes to the trie, so that we know which node needs to be updated or deleted in the database type nodeTracer struct { inserts map[Path]struct{} deletes map[Path]struct{} @@ -16,17 +17,23 @@ func newTracer() *nodeTracer { } } +// Tracks the newly inserted trie node. If the trie node was previously deleted, remove it from the deletion set +// as it means that the node will not be deleted in the database func (t *nodeTracer) onInsert(key *Path) { k := *key if _, present := t.deletes[k]; present { + delete(t.deletes, k) return } t.inserts[k] = struct{}{} } +// Tracks the newly deleted trie node. 
If the trie node was previously inserted, remove it from the insertion set +// as it means that the node will not be inserted in the database func (t *nodeTracer) onDelete(key *Path) { k := *key if _, present := t.inserts[k]; present { + delete(t.inserts, k) return } t.deletes[k] = struct{}{} diff --git a/core/trie2/trie.go b/core/trie2/trie.go index dba6b96fff..d3830e9bc3 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -454,6 +454,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { } } +// Resolves the node at the given path from the database func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { buf := bufferPool.Get().(*bytes.Buffer) buf.Reset() @@ -471,6 +472,7 @@ func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { return decodeNode(blob, hash.Felt, path.Len(), t.height) } +// Calculate the hash of the root node func (t *Trie) hashRoot() (node, node) { if t.root == nil { return &hashNode{Felt: felt.Zero}, nil From cf9562af70cbb08815ac5844b8a10e76b37f7806 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 Jan 2025 10:14:21 +0800 Subject: [PATCH 37/44] fix rebase --- core/trie/bitarray_test.go | 424 ------------------------------------- 1 file changed, 424 deletions(-) diff --git a/core/trie/bitarray_test.go b/core/trie/bitarray_test.go index 5581346899..e711a9ddd6 100644 --- a/core/trie/bitarray_test.go +++ b/core/trie/bitarray_test.go @@ -3,7 +3,6 @@ package trie import ( "bytes" "encoding/binary" - "math" "math/bits" "testing" @@ -12,8 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -var maxBits = [4]uint64{math.MaxUint64, math.MaxUint64, math.MaxUint64, math.MaxUint64} - const ( ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 ) @@ -2092,424 +2089,3 @@ func TestSubset(t *testing.T) { }) } } - -func BenchmarkBitArrayBytes(b *testing.B) { - testCases := []struct { - name string - ba bitArray - }{ - { - name: "empty", - ba: bitArray{pos: 0, words: maxBitArray}, - }, - { - name: "pos_38", - ba: bitArray{pos: 38, words: maxBitArray}, - }, - { - name: "pos_100", - ba: bitArray{pos: 100, words: maxBitArray}, - }, - { - name: "pos_201", - ba: bitArray{pos: 201, words: maxBitArray}, - }, - { - name: "pos_255", - ba: bitArray{pos: 255, words: maxBitArray}, - }, - } - - for _, tc := range testCases { - b.Run(tc.name, func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - _ = tc.ba.Bytes() - } - }) - } -} - -func TestRsh(t *testing.T) { - tests := []struct { - name string - initial *bitArray - shiftBy uint8 - expected *bitArray - }{ - { - name: "zero length array", - initial: &bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - shiftBy: 5, - expected: &bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "shift by 0", - initial: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - shiftBy: 0, - expected: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "shift by more than length", - initial: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - shiftBy: 65, - expected: &bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "shift by less than 64", - initial: &bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - shiftBy: 32, - expected: &bitArray{ - len: 96, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x00000000FFFFFFFF, 0, 0}, - }, - }, - { - name: "shift by exactly 64", - initial: &bitArray{ - len: 128, - words: 
[4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - shiftBy: 64, - expected: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "shift by 128", - initial: &bitArray{ - len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, - }, - shiftBy: 128, - expected: &bitArray{ - len: 123, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - }, - { - name: "shift by 192", - initial: &bitArray{ - len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, - }, - shiftBy: 192, - expected: &bitArray{ - len: 59, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := new(bitArray).Rsh(tt.initial, tt.shiftBy) - if !result.Equal(tt.expected) { - t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) - } - }) - } -} - -func TestPrefixEqual(t *testing.T) { - tests := []struct { - name string - a *bitArray - b *bitArray - want bool - }{ - { - name: "equal lengths, equal values", - a: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - b: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - want: true, - }, - { - name: "equal lengths, different values", - a: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - b: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, - }, - want: false, - }, - { - name: "different lengths, a longer but same prefix", - a: &bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - b: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - want: true, - }, - { - name: "different lengths, b longer but same prefix", - a: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - b: &bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - want: true, - }, - { - name: "different lengths, different prefix", - a: &bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - b: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, - }, - want: false, - }, - { - name: "zero length arrays", - a: &bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - b: &bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - want: true, - }, - { - name: "one zero length array", - a: &bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - b: &bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - want: true, - }, - { - name: "max length difference", - a: &bitArray{ - len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, - }, - b: &bitArray{ - len: 1, - words: [4]uint64{0x1, 0, 0, 0}, - }, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.a.PrefixEqual(tt.b); got != tt.want { - t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) - } - // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) - if got := tt.b.PrefixEqual(tt.a); got != tt.want { - t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) - } - }) - } -} - -func TestTruncate(t *testing.T) { - tests := []struct { - name string - initial bitArray - length uint8 - expected bitArray - }{ - { - name: "truncate to zero", - initial: bitArray{ - len: 64, - words: 
[4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - length: 0, - expected: bitArray{ - len: 0, - words: [4]uint64{0, 0, 0, 0}, - }, - }, - { - name: "truncate within first word - 32 bits", - initial: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - length: 32, - expected: bitArray{ - len: 32, - words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "truncate to single bit", - initial: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - length: 1, - expected: bitArray{ - len: 1, - words: [4]uint64{0x0000000000000001, 0, 0, 0}, - }, - }, - { - name: "truncate across words - 100 bits", - initial: bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - length: 100, - expected: bitArray{ - len: 100, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0x0000000FFFFFFFFF, 0, 0}, - }, - }, - { - name: "truncate at word boundary - 64 bits", - initial: bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - length: 64, - expected: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "truncate at word boundary - 128 bits", - initial: bitArray{ - len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, - }, - length: 128, - expected: bitArray{ - len: 128, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, - }, - }, - { - name: "truncate in third word - 150 bits", - initial: bitArray{ - len: 192, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0}, - }, - length: 150, - expected: bitArray{ - len: 150, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x3FFFFF, 0}, - }, - }, - { - name: "truncate in fourth word - 220 bits", - initial: bitArray{ - len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, - }, - length: 220, - expected: bitArray{ - len: 220, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFF}, - }, - }, - { - name: "truncate max length - 251 bits", - initial: bitArray{ - len: 255, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF}, - }, - length: 251, - expected: bitArray{ - len: 251, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0x7FFFFFFFFFFFFFF}, - }, - }, - { - name: "truncate sparse bits", - initial: bitArray{ - len: 128, - words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, - }, - length: 100, - expected: bitArray{ - len: 100, - words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, - }, - }, - { - name: "no change when new length equals current length", - initial: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - length: 64, - expected: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - }, - { - name: "no change when new length greater than current length", - initial: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - length: 128, - expected: bitArray{ - len: 64, - words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0, 0, 0}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := new(bitArray).Truncate(&tt.initial, tt.length) - if !result.Equal(&tt.expected) { - t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) - } - }) - } -} From 7f70307aa1eb2c890749193d7e16e32701e2c137 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 
Jan 2025 10:15:59 +0800 Subject: [PATCH 38/44] remove proof_test changes --- core/trie/proof_test.go | 41 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index dc441036cc..5a43932042 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -14,6 +14,8 @@ import ( ) func TestProve(t *testing.T) { + t.Parallel() + n := 1000 tempTrie, records := nonRandomTrie(t, n) @@ -34,6 +36,8 @@ func TestProve(t *testing.T) { } func TestProveNonExistent(t *testing.T) { + t.Parallel() + n := 1000 tempTrie, _ := nonRandomTrie(t, n) @@ -56,6 +60,7 @@ func TestProveNonExistent(t *testing.T) { } func TestProveRandom(t *testing.T) { + t.Parallel() tempTrie, records := randomTrie(t, 1000) for _, record := range records { @@ -73,6 +78,8 @@ func TestProveRandom(t *testing.T) { } func TestProveCustom(t *testing.T) { + t.Parallel() + tests := []testTrie{ { name: "simple binary", @@ -166,6 +173,8 @@ func TestProveCustom(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + t.Parallel() + tr, _ := test.buildFn(t) for _, tc := range test.testKeys { @@ -188,6 +197,8 @@ func TestProveCustom(t *testing.T) { // TestRangeProof tests normal range proof with both edge proofs func TestRangeProof(t *testing.T) { + t.Parallel() + n := 500 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -215,6 +226,8 @@ func TestRangeProof(t *testing.T) { // TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs func TestRangeProofWithNonExistentProof(t *testing.T) { + t.Parallel() + n := 500 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -247,8 +260,9 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { // TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. // One scenario is when there is a gap between the first element and the left edge proof. 
-// TODO(weiihann): this is failing func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { + t.Parallel() + n := 500 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -270,17 +284,20 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { } _, err = trie.VerifyRangeProof(root, first, keys, values, proof) - t.Log(err) require.Error(t, err) } func TestOneElementRangeProof(t *testing.T) { + t.Parallel() + n := 1000 tr, records := randomTrie(t, n) root, err := tr.Root() require.NoError(t, err) t.Run("both edge proofs with the same key", func(t *testing.T) { + t.Parallel() + start := 100 proof := trie.NewProofNodeSet() err = tr.GetRangeProof(records[start].key, records[start].key, proof) @@ -291,6 +308,8 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("left non-existent edge proof", func(t *testing.T) { + t.Parallel() + start := 100 proof := trie.NewProofNodeSet() err = tr.GetRangeProof(decrementFelt(records[start].key), records[start].key, proof) @@ -301,6 +320,8 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("right non-existent edge proof", func(t *testing.T) { + t.Parallel() + end := 100 proof := trie.NewProofNodeSet() err = tr.GetRangeProof(records[end].key, incrementFelt(records[end].key), proof) @@ -311,6 +332,8 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("both non-existent edge proofs", func(t *testing.T) { + t.Parallel() + start := 100 first, last := decrementFelt(records[start].key), incrementFelt(records[start].key) proof := trie.NewProofNodeSet() @@ -322,6 +345,8 @@ func TestOneElementRangeProof(t *testing.T) { }) t.Run("1 key trie", func(t *testing.T) { + t.Parallel() + tr, records := build1KeyTrie(t) root, err := tr.Root() require.NoError(t, err) @@ -337,6 +362,8 @@ func TestOneElementRangeProof(t *testing.T) { // TestAllElementsRangeProof tests the range proof with all elements and nil proof. func TestAllElementsRangeProof(t *testing.T) { + t.Parallel() + n := 1000 tr, records := randomTrie(t, n) root, err := tr.Root() @@ -363,6 +390,8 @@ func TestAllElementsRangeProof(t *testing.T) { // TestSingleSideRangeProof tests the range proof starting with zero. func TestSingleSideRangeProof(t *testing.T) { + t.Parallel() + tr, records := randomTrie(t, 1000) root, err := tr.Root() require.NoError(t, err) @@ -384,8 +413,8 @@ func TestSingleSideRangeProof(t *testing.T) { } } -// TODO(weiihann): this is failing func TestGappedRangeProof(t *testing.T) { + t.Parallel() t.Skip("gapped keys will sometimes succeed, the current proof format is not able to handle this") tr, records := nonRandomTrie(t, 5) @@ -413,6 +442,8 @@ func TestGappedRangeProof(t *testing.T) { } func TestEmptyRangeProof(t *testing.T) { + t.Parallel() + tr, records := randomTrie(t, 1000) root, err := tr.Root() require.NoError(t, err) @@ -441,6 +472,8 @@ func TestEmptyRangeProof(t *testing.T) { } func TestHasRightElement(t *testing.T) { + t.Parallel() + tr, records := randomTrie(t, 500) root, err := tr.Root() require.NoError(t, err) @@ -492,6 +525,8 @@ func TestHasRightElement(t *testing.T) { // TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. 
func TestBadRangeProof(t *testing.T) { + t.Parallel() + tr, records := randomTrie(t, 1000) root, err := tr.Root() require.NoError(t, err) From 598fd2c08e766b349d3f4342eccaf9f6402ac5f8 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 Jan 2025 13:07:05 +0800 Subject: [PATCH 39/44] add bitarray_test.go --- core/trie2/trieutils/bitarray_test.go | 2091 +++++++++++++++++++++++++ 1 file changed, 2091 insertions(+) create mode 100644 core/trie2/trieutils/bitarray_test.go diff --git a/core/trie2/trieutils/bitarray_test.go b/core/trie2/trieutils/bitarray_test.go new file mode 100644 index 0000000000..6589f94939 --- /dev/null +++ b/core/trie2/trieutils/bitarray_test.go @@ -0,0 +1,2091 @@ +package trieutils + +import ( + "bytes" + "encoding/binary" + "math/bits" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 +) + +func TestBytes(t *testing.T) { + tests := []struct { + name string + ba BitArray + want [32]byte + }{ + { + name: "length == 0", + ba: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: [32]byte{}, + }, + { + name: "length < 64", + ba: BitArray{len: 38, words: [4]uint64{0x3FFFFFFFFF, 0, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + ba: BitArray{len: 100, words: [4]uint64{maxUint64, 0xFFFFFFFFF, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "128 <= length < 192", + ba: BitArray{len: 130, words: [4]uint64{maxUint64, maxUint64, 0x3, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x3) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "192 <= length < 255", + ba: BitArray{len: 201, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x1FF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 254", + ba: BitArray{len: 254, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x3FFFFFFFFFFFFFFF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 255", + ba: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], ones63) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.Bytes() + if !bytes.Equal(got[:], tt.want[:]) { + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) + } + + // check if the received bytes has the same bit count as the BitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = 
%v, want %v", count, tt.ba.len) + } + }) + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *BitArray + shiftBy uint8 + expected *BitArray + }{ + { + name: "zero length array", + initial: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + shiftBy: 0, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 65, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 32, + expected: &BitArray{ + len: 96, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 64, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by 127", + initial: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + shiftBy: 127, + expected: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 128, + expected: &BitArray{ + len: 123, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 192, + expected: &BitArray{ + len: 59, + words: [4]uint64{0x7FFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).Rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 5, + want: emptyBitArray, + }, + { + name: "shift by 0", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 4, + want: &BitArray{ + len: 8, + words: [4]uint64{0xF0, 0, 0, 0}, // 11110000 + }, + }, + { + name: "shift across word boundary", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 62, + want: &BitArray{ + len: 66, + words: [4]uint64{0xC000000000000000, 0x3, 0, 0}, + }, + }, + { + name: "shift by 64 (full word)", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 64, + want: &BitArray{ + len: 72, + words: [4]uint64{0, 0xFF, 0, 0}, + }, + }, + { + name: "shift by 128", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 128, + want: &BitArray{ + len: 136, + words: [4]uint64{0, 0, 0xFF, 0}, + }, + }, + { + name: "shift by 192", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 192, + want: &BitArray{ + len: 200, 
+ words: [4]uint64{0, 0, 0, 0xFF}, + }, + }, + { + name: "shift causing length overflow", + x: &BitArray{ + len: 200, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{ + 0xF000000000000000, + 0xF, + 0, + 0, + }, + }, + }, + { + name: "shift sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + n: 4, + want: &BitArray{ + len: 12, + words: [4]uint64{0xAA0, 0, 0, 0}, // 101010100000 + }, + }, + { + name: "shift partial word across boundary", + x: &BitArray{ + len: 100, + words: [4]uint64{0xFF, 0xFF, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 160, + words: [4]uint64{ + 0xF000000000000000, + 0xF00000000000000F, + 0xF, + 0, + }, + }, + }, + { + name: "near maximum length shift", + x: &BitArray{ + len: 251, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 4, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{0xFF0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Lsh(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("Lsh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "both empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "first array empty", + x: emptyBitArray, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "second array empty", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: emptyBitArray, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "different lengths within word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 2, + words: [4]uint64{0x3, 0, 0, 0}, // 11 + }, + want: &BitArray{ + len: 6, + words: [4]uint64{0x3F, 0, 0, 0}, // 111111 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 62, + words: [4]uint64{0x3FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 66, + words: [4]uint64{maxUint64, 0x3, 0, 0}, + }, + }, + { + name: "across multiple words", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + y: &BitArray{ + len: 8, + words: [4]uint64{0x55, 0, 0, 0}, // 01010101 + }, + want: &BitArray{ + len: 16, + words: [4]uint64{0xAA55, 0, 0, 0}, // 1010101001010101 + }, + }, + { + name: "result exactly at length limit", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + want: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests 
{
+		t.Run(tt.name, func(t *testing.T) {
+			got := new(BitArray).Append(tt.x, tt.y)
+			if !got.Equal(tt.want) {
+				t.Errorf("Append() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestEqualMSBs(t *testing.T) {
+	tests := []struct {
+		name string
+		a    *BitArray
+		b    *BitArray
+		want bool
+	}{
+		{
+			name: "equal lengths, equal values",
+			a: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			b: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			want: true,
+		},
+		{
+			name: "equal lengths, different values",
+			a: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			b: &BitArray{
+				len:   64,
+				words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0},
+			},
+			want: false,
+		},
+		{
+			name: "different lengths, a longer but same prefix",
+			a: &BitArray{
+				len:   128,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			b: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			want: true,
+		},
+		{
+			name: "different lengths, b longer but same prefix",
+			a: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			b: &BitArray{
+				len:   128,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			want: true,
+		},
+		{
+			name: "different lengths, different prefix",
+			a: &BitArray{
+				len:   128,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			b: &BitArray{
+				len:   64,
+				words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0},
+			},
+			want: false,
+		},
+		{
+			name: "zero length arrays",
+			a: &BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+			b: &BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+			want: true,
+		},
+		{
+			name: "one zero length array",
+			a: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			b: &BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+			want: true,
+		},
+		{
+			name: "max length difference",
+			a: &BitArray{
+				len:   251,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF},
+			},
+			b: &BitArray{
+				len:   1,
+				words: [4]uint64{0x1, 0, 0, 0},
+			},
+			want: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.a.EqualMSBs(tt.b); got != tt.want {
+				t.Errorf("EqualMSBs() = %v, want %v", got, tt.want)
+			}
+			// Test symmetry: a.EqualMSBs(b) should equal b.EqualMSBs(a)
+			if got := tt.b.EqualMSBs(tt.a); got != tt.want {
+				t.Errorf("EqualMSBs() symmetric test = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestLSBs(t *testing.T) {
+	tests := []struct {
+		name string
+		x    *BitArray
+		pos  uint8
+		want *BitArray
+	}{
+		{
+			name: "zero position",
+			x: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			pos: 0,
+			want: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+		},
+		{
+			name: "position beyond length",
+			x: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			pos: 65,
+			want: &BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+		},
+		{
+			name: "get last 4 bits",
+			x: &BitArray{
+				len:   8,
+				words: [4]uint64{0xFF, 0, 0, 0}, // 11111111
+			},
+			pos: 4,
+			want: &BitArray{
+				len:   4,
+				words: [4]uint64{0x0F, 0, 0, 0}, // 1111
+			},
+		},
+		{
+			name: "get bits across word boundary",
+			x: &BitArray{
+				len:   128,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			pos: 64,
+			want: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+		},
+		{
+			name: "get bits from max length array",
+			x: &BitArray{
+				len:   251,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFF},
+			},
+			pos: 200,
+			want: &BitArray{
+				len:   51,
+				words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0,
0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBs(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBsFromLSB(t *testing.T) { + tests := []struct { + name string + initial BitArray + length uint8 + expected BitArray + }{ + { + name: "zero", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 0, + expected: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get 32 LSBs", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 32, + expected: BitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get 1 LSB", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 1, + expected: BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + { + name: "get 100 LSBs across words", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "get 64 LSBs at word boundary", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get 128 LSBs at word boundary", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 128, + expected: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "get 150 LSBs in third word", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 150, + expected: BitArray{ + len: 150, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, + }, + }, + { + name: "get 220 LSBs in fourth word", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 220, + expected: BitArray{ + len: 220, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, + }, + }, + { + name: "get 251 LSBs", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 251, + expected: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "get 100 LSBs from sparse bits", + initial: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + 
initial: BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			length: 128,
+			expected: BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := new(BitArray).LSBsFromLSB(&tt.initial, tt.length)
+			if !result.Equal(&tt.expected) {
+				t.Errorf("LSBsFromLSB() got = %+v, want %+v", result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestMSBs(t *testing.T) {
+	tests := []struct {
+		name string
+		x    *BitArray
+		n    uint8
+		want *BitArray
+	}{
+		{
+			name: "empty array",
+			x:    emptyBitArray,
+			n:    0,
+			want: emptyBitArray,
+		},
+		{
+			name: "get all bits",
+			x: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			n: 64,
+			want: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+		},
+		{
+			name: "get more bits than available",
+			x: &BitArray{
+				len:   32,
+				words: [4]uint64{0xFFFFFFFF, 0, 0, 0},
+			},
+			n: 64,
+			want: &BitArray{
+				len:   32,
+				words: [4]uint64{0xFFFFFFFF, 0, 0, 0},
+			},
+		},
+		{
+			name: "get half of available bits",
+			x: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			n: 32,
+			want: &BitArray{
+				len:   32,
+				words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0},
+			},
+		},
+		{
+			name: "get MSBs across word boundary",
+			x: &BitArray{
+				len:   128,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			n: 100,
+			want: &BitArray{
+				len:   100,
+				words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0},
+			},
+		},
+		{
+			name: "get MSBs from max length array",
+			x: &BitArray{
+				len:   255,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63},
+			},
+			n: 64,
+			want: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+		},
+		{
+			name: "get zero bits",
+			x: &BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			n: 0,
+			want: &BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+		},
+		{
+			name: "sparse bits",
+			x: &BitArray{
+				len:   128,
+				words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0},
+			},
+			n: 64,
+			want: &BitArray{
+				len:   64,
+				words: [4]uint64{0x5555555555555555, 0, 0, 0},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := new(BitArray).MSBs(tt.x, tt.n)
+			if !got.Equal(tt.want) {
+				t.Errorf("MSBs() = %v, want %v", got, tt.want)
+			}
+
+			if got.len != tt.want.len {
+				t.Errorf("MSBs() length = %d, want %d", got.len, tt.want.len)
+			}
+		})
+	}
+}
+
+func TestWriteAndUnmarshalBinary(t *testing.T) {
+	tests := []struct {
+		name string
+		ba   BitArray
+		want []byte // Expected bytes after writing
+	}{
+		{
+			name: "empty bit array",
+			ba: BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+			want: []byte{0}, // Just the length byte
+		},
+		{
+			name: "8 bits",
+			ba: BitArray{
+				len:   8,
+				words: [4]uint64{0xFF, 0, 0, 0},
+			},
+			want: []byte{8, 0xFF}, // length byte + 1 data byte
+		},
+		{
+			name: "10 bits requiring 2 bytes",
+			ba: BitArray{
+				len:   10,
+				words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary
+			},
+			want: []byte{10, 0x3, 0xFF}, // length byte + 2 data bytes
+		},
+		{
+			name: "64 bits",
+			ba: BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, 0, 0, 0},
+			},
+			want: append(
+				[]byte{64}, // length byte
+				[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}..., // 8 data bytes
+			),
+		},
+		{
+			name: "251 bits",
+			ba: BitArray{
+				len:   251,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF},
+			},
+			want: func() []byte {
+				b := make([]byte, 33) // 1 length byte + 32 data bytes
+				b[0] = 251 // length byte
+				// First byte is 0x07 (from the most significant bits)
+ b[1] = 0x07 + // Rest of the bytes are 0xFF + for i := 2; i < 33; i++ { + b[i] = 0xFF + } + return b + }(), + }, + { + name: "sparse bits", + ba: BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary + }, + want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := new(bytes.Buffer) + gotN, err := tt.ba.Write(buf) + assert.NoError(t, err) + + // Check number of bytes written + if gotN != len(tt.want) { + t.Errorf("Write() wrote %d bytes, want %d", gotN, len(tt.want)) + } + + // Check written bytes + got := buf.Bytes() + if !bytes.Equal(got, tt.want) { + t.Errorf("Write() = %v, want %v", got, tt.want) + } + + var gotBitArray BitArray + err = gotBitArray.UnmarshalBinary(got) + require.NoError(t, err) + if !gotBitArray.Equal(&tt.ba) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) + } + }) + } +} + +func TestCommonPrefix(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "one empty array", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "identical arrays - single word", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "identical arrays - multiple words", + x: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + y: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "different lengths with common prefix - first word", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different lengths with common prefix - multiple words", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + y: &BitArray{ + len: 127, + words: [4]uint64{maxUint64, ones63, 0, 0}, + }, + want: &BitArray{ + len: 127, + words: [4]uint64{maxUint64, ones63, 0, 0}, + }, + }, + { + name: "different at first bit", + x: &BitArray{ + len: 64, + words: [4]uint64{ones63, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "different in middle of first word", + x: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in second word", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, 0xFFFFFFFF0FFFFFFF, 0, 0}, + }, + y: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in third word", + x: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + y: &BitArray{ + len: 192, + words: [4]uint64{0, 0, 
0xFFFFFFFFFFFFFF0F, 0},
+			},
+			want: &BitArray{
+				len:   56,
+				words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0},
+			},
+		},
+		{
+			name: "different in last word",
+			x: &BitArray{
+				len:   251,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF},
+			},
+			y: &BitArray{
+				len:   251,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFF0FFFFFFF},
+			},
+			want: &BitArray{
+				len:   27,
+				words: [4]uint64{0x7FFFFFF, 0, 0, 0},
+			},
+		},
+		{
+			name: "sparse bits with common prefix",
+			x: &BitArray{
+				len:   128,
+				words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0},
+			},
+			y: &BitArray{
+				len:   128,
+				words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0},
+			},
+			want: &BitArray{
+				len:   52,
+				words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0},
+			},
+		},
+		{
+			name: "max length difference",
+			x: &BitArray{
+				len:   255,
+				words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63},
+			},
+			y: &BitArray{
+				len:   1,
+				words: [4]uint64{0x1, 0, 0, 0},
+			},
+			want: &BitArray{
+				len:   1,
+				words: [4]uint64{0x1, 0, 0, 0},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := new(BitArray)
+			gotSymmetric := new(BitArray)
+
+			got.CommonMSBs(tt.x, tt.y)
+			if !got.Equal(tt.want) {
+				t.Errorf("CommonMSBs() = %v, want %v", got, tt.want)
+			}
+
+			// Test symmetry: x.CommonMSBs(y) should equal y.CommonMSBs(x)
+			gotSymmetric.CommonMSBs(tt.y, tt.x)
+			if !gotSymmetric.Equal(tt.want) {
+				t.Errorf("CommonMSBs() symmetric test = %v, want %v", gotSymmetric, tt.want)
+			}
+		})
+	}
+}
+
+func TestIsBitSetFromLSB(t *testing.T) {
+	tests := []struct {
+		name string
+		ba   BitArray
+		pos  uint8
+		want bool
+	}{
+		{
+			name: "empty array",
+			ba: BitArray{
+				len:   0,
+				words: [4]uint64{0, 0, 0, 0},
+			},
+			pos:  0,
+			want: false,
+		},
+		{
+			name: "first bit set",
+			ba: BitArray{
+				len:   64,
+				words: [4]uint64{1, 0, 0, 0},
+			},
+			pos:  0,
+			want: true,
+		},
+		{
+			name: "last bit in first word",
+			ba: BitArray{
+				len:   64,
+				words: [4]uint64{1 << 63, 0, 0, 0},
+			},
+			pos:  63,
+			want: true,
+		},
+		{
+			name: "first bit in second word",
+			ba: BitArray{
+				len:   128,
+				words: [4]uint64{0, 1, 0, 0},
+			},
+			pos:  64,
+			want: true,
+		},
+		{
+			name: "bit beyond length",
+			ba: BitArray{
+				len:   64,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			pos:  65,
+			want: false,
+		},
+		{
+			name: "alternating bits",
+			ba: BitArray{
+				len:   8,
+				words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary
+			},
+			pos:  1,
+			want: true,
+		},
+		{
+			name: "alternating bits - unset position",
+			ba: BitArray{
+				len:   8,
+				words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary
+			},
+			pos:  0,
+			want: false,
+		},
+		{
+			name: "bit in last word",
+			ba: BitArray{
+				len:   251,
+				words: [4]uint64{0, 0, 0, 1 << 59},
+			},
+			pos:  251,
+			want: false, // position 251 is beyond the highest valid bit (250)
+		},
+		{
+			name: "251 bits",
+			ba: BitArray{
+				len:   251,
+				words: [4]uint64{0, 0, 0, 1 << 58},
+			},
+			pos:  250,
+			want: true,
+		},
+		{
+			name: "highest valid bit (255)",
+			ba: BitArray{
+				len:   255,
+				words: [4]uint64{0, 0, 0, 1 << 62}, // bit 254 set (the highest valid bit for len 255)
+			},
+			pos:  254,
+			want: true,
+		},
+		{
+			name: "position at length boundary",
+			ba: BitArray{
+				len:   100,
+				words: [4]uint64{maxUint64, maxUint64, 0, 0},
+			},
+			pos:  100,
+			want: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := tt.ba.IsBitSetFromLSB(tt.pos)
+			if got != tt.want {
+				t.Errorf("IsBitSetFromLSB(%d) = %v, want %v", tt.pos, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestIsBitSet(t *testing.T) {
+	tests := []struct {
+		name string
+		ba   BitArray
+		pos
uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit (MSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x80, 0, 0, 0}, // 10000000 + }, + pos: 0, + want: true, + }, + { + name: "last bit (LSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x01, 0, 0, 0}, // 00000001 + }, + pos: 7, + want: true, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 0, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 1, + want: false, + }, + { + name: "position beyond length", + ba: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + pos: 8, + want: false, + }, + { + name: "bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 0, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 99, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSet(tt.pos) + if got != tt.want { + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + +func TestFeltConversion(t *testing.T) { + tests := []struct { + name string + ba BitArray + length uint8 + want string // hex representation of felt + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + length: 0, + want: "0x0", + }, + { + name: "single word", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + want: "0xffffffffffffffff", + }, + { + name: "two words", + ba: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 128, + want: "0xffffffffffffffffffffffffffffffff", + }, + { + name: "three words", + ba: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 192, + want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + length: 251, + want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "sparse bits", + ba: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 128, + want: "0x5555555555555555aaaaaaaaaaaaaaaa", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Felt() conversion + gotFelt := tt.ba.Felt() + assert.Equal(t, tt.want, gotFelt.String()) + + // Test SetFelt() conversion (round trip) + var newBA BitArray + newBA.SetFelt(tt.length, &gotFelt) + assert.Equal(t, tt.ba.len, newBA.len) + assert.Equal(t, tt.ba.words, newBA.words) + }) + } +} + +func TestSetFeltValidation(t *testing.T) { + tests := []struct { + name string + feltStr string + length uint8 + shouldMatch bool + }{ + { + name: "valid felt with matching length", + feltStr: "0xf", + length: 4, + shouldMatch: true, + }, + { + name: "felt larger than specified length", + feltStr: "0xff", + length: 4, + shouldMatch: false, + }, + { + name: "zero felt with non-zero length", + feltStr: "0x0", + length: 8, + shouldMatch: true, + }, + { + name: "max felt with max length", 
+ feltStr: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + length: 251, + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f felt.Felt + _, err := f.SetString(tt.feltStr) + require.NoError(t, err) + + var ba BitArray + ba.SetFelt(tt.length, &f) + + // Convert back to felt and compare + roundTrip := ba.Felt() + if tt.shouldMatch { + assert.True(t, roundTrip.Equal(&f), + "expected %s, got %s", f.String(), roundTrip.String()) + } else { + assert.False(t, roundTrip.Equal(&f), + "values should not match: original %s, roundtrip %s", + f.String(), roundTrip.String()) + } + }) + } +} + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + bit uint8 + want BitArray + }{ + { + name: "set bit 0", + bit: 0, + want: BitArray{ + len: 1, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "set bit 1", + bit: 1, + want: BitArray{ + len: 1, + words: [4]uint64{1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBit(tt.bit) + if !got.Equal(&tt.want) { + t.Errorf("SetBit(%v) = %v, want %v", tt.bit, got, tt.want) + } + }) + } +} + +func TestCmp(t *testing.T) { + tests := []struct { + name string + x BitArray + y BitArray + want int + }{ + { + name: "equal empty arrays", + x: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + y: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: 0, + }, + { + name: "equal non-empty arrays", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 0, + }, + { + name: "different lengths - x shorter", + x: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "different lengths - x longer", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, x < y in first word", + x: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "same length, x > y in first word", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, difference in last word", + x: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFFF}}, + y: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFF0}}, + want: 1, + }, + { + name: "same length, sparse bits", + x: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}}, + y: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}}, + want: 1, + }, + { + name: "max length difference", + x: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + y: BitArray{len: 1, words: [4]uint64{0x1, 0, 0, 0}}, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.x.Cmp(&tt.y) + if got != tt.want { + t.Errorf("Cmp() = %v, want %v", got, tt.want) + } + + // Test anti-symmetry: if x.Cmp(y) = z then y.Cmp(x) = -z + gotReverse := tt.y.Cmp(&tt.x) + if gotReverse != -tt.want { + t.Errorf("Reverse Cmp() = %v, want %v", gotReverse, -tt.want) + } + + // Test transitivity with self: x.Cmp(x) should always be 0 + if tt.x.Cmp(&tt.x) != 0 { + t.Error("Self Cmp() != 0") + } + }) + } +} + +func 
TestSetBytes(t *testing.T) { + tests := []struct { + name string + length uint8 + data []byte + want BitArray + }{ + { + name: "empty data", + length: 0, + data: []byte{}, + want: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "single byte", + length: 8, + data: []byte{0xFF}, + want: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + }, + { + name: "two bytes", + length: 16, + data: []byte{0xAA, 0xFF}, + want: BitArray{ + len: 16, + words: [4]uint64{0xAAFF, 0, 0, 0}, + }, + }, + { + name: "three bytes", + length: 24, + data: []byte{0xAA, 0xBB, 0xCC}, + want: BitArray{ + len: 24, + words: [4]uint64{0xAABBCC, 0, 0, 0}, + }, + }, + { + name: "four bytes", + length: 32, + data: []byte{0xAA, 0xBB, 0xCC, 0xDD}, + want: BitArray{ + len: 32, + words: [4]uint64{0xAABBCCDD, 0, 0, 0}, + }, + }, + { + name: "eight bytes (full word)", + length: 64, + data: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + want: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "sixteen bytes (two words)", + length: 128, + data: []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + }, + want: BitArray{ + len: 128, + words: [4]uint64{ + 0xAAAAAAAAAAAAAAAA, + 0xFFFFFFFFFFFFFFFF, + 0, 0, + }, + }, + }, + { + name: "thirty-two bytes (full array)", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + { + name: "truncate to length", + length: 4, + data: []byte{0xFF}, + want: BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + }, + { + name: "data larger than 32 bytes", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, // extra bytes should be ignored + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBytes(tt.length, tt.data) + if !got.Equal(&tt.want) { + t.Errorf("SetBytes(%d, %v) = %v, want %v", tt.length, tt.data, got, tt.want) + } + }) + } +} + +func TestSubset(t *testing.T) { + tests := []struct { + name string + x *BitArray + startPos uint8 + endPos uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + startPos: 0, + endPos: 0, + want: emptyBitArray, + }, + { + name: "invalid range - start >= end", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 4, + endPos: 2, + want: emptyBitArray, + }, + { + name: "invalid range - start >= length", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 8, + endPos: 10, + want: emptyBitArray, + }, + { + name: "full range", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 0, + endPos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "middle subset", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 2, + endPos: 5, + want: &BitArray{ + 
len: 3, + words: [4]uint64{0x7, 0, 0, 0}, // 111 + }, + }, + { + name: "end subset", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 5, + endPos: 8, + want: &BitArray{ + len: 3, + words: [4]uint64{0x7, 0, 0, 0}, // 111 + }, + }, + { + name: "start subset", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 0, + endPos: 3, + want: &BitArray{ + len: 3, + words: [4]uint64{0x7, 0, 0, 0}, // 111 + }, + }, + { + name: "endPos beyond length", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 4, + endPos: 10, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + startPos: 2, + endPos: 6, + want: &BitArray{ + len: 4, + words: [4]uint64{0xA, 0, 0, 0}, // 1010 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + startPos: 60, + endPos: 68, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + }, + { + name: "max length subset", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + startPos: 1, + endPos: 251, + want: &BitArray{ + len: 250, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x3FFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Subset(tt.x, tt.startPos, tt.endPos) + if !got.Equal(tt.want) { + t.Errorf("Subset(%v, %d, %d) = %v, want %v", tt.x, tt.startPos, tt.endPos, got, tt.want) + } + }) + } +} From 1d0f8684958d6010521062b2ab129a7aae562b47 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 Jan 2025 13:22:49 +0800 Subject: [PATCH 40/44] linter --- core/trie2/trieutils/bitarray.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index b83a59923e..e9f781fab3 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -18,7 +18,7 @@ const ( MaxBitArraySize = 33 // (1 + 4 * 8) bytes ) -var emptyBitArray = new(BitArray) //nolint:unused //TODO(weiihann): remove this nolint when we replace legacy trie +var emptyBitArray = new(BitArray) // TODO(weiihann): remove this nolint when we replace legacy trie // Represents a bit array with length representing the number of used bits. // It uses a little endian representation to do bitwise operations of the words efficiently. From dcbe81578918acc278eb2649463156f4bd7b29b9 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 23 Jan 2025 14:11:36 +0800 Subject: [PATCH 41/44] add comments --- core/trie2/trienode/node.go | 1 + core/trie2/trienode/nodeset.go | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/core/trie2/trienode/node.go b/core/trie2/trienode/node.go index 55f58ab34f..c27e9e7775 100644 --- a/core/trie2/trienode/node.go +++ b/core/trie2/trienode/node.go @@ -4,6 +4,7 @@ import ( "github.com/NethermindEth/juno/core/felt" ) +// Represents a raw trie node, which contains the encoded blob and the hash of the node. 
type Node struct { blob []byte hash felt.Felt diff --git a/core/trie2/trienode/nodeset.go b/core/trie2/trienode/nodeset.go index ab891be3a0..8b0c2f6f02 100644 --- a/core/trie2/trienode/nodeset.go +++ b/core/trie2/trienode/nodeset.go @@ -9,11 +9,13 @@ import ( "github.com/NethermindEth/juno/core/trie2/trieutils" ) +// Contains a set of nodes, which are indexed by their path in the trie. +// It is not thread safe. type NodeSet struct { - Owner felt.Felt + Owner felt.Felt // The owner (i.e. contract address) Nodes map[trieutils.BitArray]*Node - updates int - deletes int + updates int // the count of updated and inserted nodes + deletes int // the count of deleted nodes } func NewNodeSet(owner felt.Felt) *NodeSet { @@ -29,6 +31,7 @@ func (ns *NodeSet) Add(key trieutils.BitArray, node *Node) { ns.Nodes[key] = node } +// Iterates over the nodes in a sorted order and calls the callback for each node. func (ns *NodeSet) ForEach(desc bool, callback func(key trieutils.BitArray, node *Node) error) error { paths := make([]trieutils.BitArray, 0, len(ns.Nodes)) for key := range ns.Nodes { @@ -54,6 +57,8 @@ func (ns *NodeSet) ForEach(desc bool, callback func(key trieutils.BitArray, node return nil } +// Merges the other node set into the current node set. +// The owner of both node sets must be the same. func (ns *NodeSet) MergeSet(other *NodeSet) error { if ns.Owner != other.Owner { return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, other.Owner) @@ -64,6 +69,7 @@ func (ns *NodeSet) MergeSet(other *NodeSet) error { return nil } +// Adds a set of nodes to the current node set. func (ns *NodeSet) Merge(owner felt.Felt, other map[trieutils.BitArray]*Node) error { if ns.Owner != owner { return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, owner) From 7a851a87299b3f21a87e9199c2abd5ce6536ee65 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 6 Feb 2025 17:01:54 +0800 Subject: [PATCH 42/44] Remove leading zero bytes from bit array --- core/trie2/trieutils/bitarray.go | 33 +++++++++++++++++++------ core/trie2/trieutils/bitarray_test.go | 35 +++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go index e9f781fab3..6d53057225 100644 --- a/core/trie2/trieutils/bitarray.go +++ b/core/trie2/trieutils/bitarray.go @@ -15,6 +15,7 @@ import ( const ( maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF maxUint8 = uint8(math.MaxUint8) + bytes32 = 32 MaxBitArraySize = 33 // (1 + 4 * 8) bytes ) @@ -456,7 +457,7 @@ func (b *BitArray) IsEmpty() bool { // Serialises the BitArray into a bytes buffer in the following format: // - First byte: length of the bit array (0-255) -// - Remaining bytes: the necessary bytes included in big endian order +// - Remaining bytes: the necessary bytes included in big endian order, without leading zeros // Example: // // BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] @@ -481,16 +482,17 @@ func (b *BitArray) UnmarshalBinary(data []byte) error { } length := data[0] - byteCount := (uint(length) + 7) / 8 // Round up to nearest byte + byteCount := (int(length) + 7) / 8 // Get the total number of bytes needed to represent the bit array - if len(data) < int(byteCount)+1 { - return fmt.Errorf("invalid data length: got %d bytes, expected %d", len(data), byteCount+1) + if len(data) > byteCount+1 { + return fmt.Errorf("invalid data length: got %d bytes, expected <= %d", len(data), byteCount+1) } b.len = length var bs 
[32]byte - copy(bs[32-byteCount:], data[1:]) + bitArrBytes := data[1:] + copy(bs[32-len(bitArrBytes):], bitArrBytes) // Fill up the non-zero bytes at the end of the byte array b.setBytes32(bs[:]) return nil @@ -715,14 +717,31 @@ func (b *BitArray) byteCount() uint { } // Returns a slice containing only the bytes that are actually used by the bit array, -// as specified by the length. The returned slice is in big-endian order. +// as specified by the length. Leading zero bytes will be removed. +// The returned slice is in big-endian order. // // Example: // // len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] func (b *BitArray) activeBytes() []byte { + if b.len == 0 { + return nil + } + wordsBytes := b.Bytes() - return wordsBytes[32-b.byteCount():] + end := uint(bytes32) + start := end - b.byteCount() + + // Find first non-zero byte + for start < end-1 && wordsBytes[start] == 0 { + start++ + } + + if start == end { + return nil + } + + return wordsBytes[start:end] } func (b *BitArray) rsh64(x *BitArray) { diff --git a/core/trie2/trieutils/bitarray_test.go b/core/trie2/trieutils/bitarray_test.go index 6589f94939..1835b538bf 100644 --- a/core/trie2/trieutils/bitarray_test.go +++ b/core/trie2/trieutils/bitarray_test.go @@ -1095,6 +1095,41 @@ func TestWriteAndUnmarshalBinary(t *testing.T) { }, want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes }, + { + name: "leading zeros in first byte", + ba: BitArray{ + len: 8, + words: [4]uint64{0x0F, 0, 0, 0}, // 00001111 + }, + want: []byte{8, 0x0F}, // length byte + 1 data byte + }, + { + name: "all leading zeros in first byte", + ba: BitArray{ + len: 8, + words: [4]uint64{0x00, 0, 0, 0}, // 00000000 + }, + want: []byte{8, 0x00}, // length byte + 1 data byte + }, + { + name: "leading zeros across multiple bytes", + ba: BitArray{ + len: 24, + words: [4]uint64{0x0000FF, 0, 0, 0}, // 000000000000000011111111 + }, + want: []byte{24, 0xFF}, // length byte + 1 data byte + }, + { + name: "leading zeros in large number", + ba: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, // All 1s in lower bits, zeros in upper bits + }, + want: append( + []byte{255}, // length byte + bytes.Repeat([]byte{0xFF}, 16)..., // 16 bytes of all 1s + ), + }, } for _, tt := range tests { From 5f183decdd8ca221a7e82eba7238b5f15a433ffc Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 12 Feb 2025 09:23:23 +0800 Subject: [PATCH 43/44] comments --- db/buckets.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/db/buckets.go b/db/buckets.go index fe6ea5a85c..329eb15993 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -38,9 +38,9 @@ const ( MempoolTail // key of the tail node MempoolLength // number of transactions MempoolNode - ClassTrie - ContractTrieContract - ContractTrieStorage + ClassTrie // ClassTrie + Node path -> Trie Node + ContractTrieContract // ContractTrieContract + Node path -> Trie Node + ContractTrieStorage // ContractTrieStorage + Contract Address + Node path -> Trie Node ) // Key flattens a prefix and series of byte arrays into a single []byte. 
From 18b6b2a05dee9eb4569fca4c44e1e5a8187c1c6e Mon Sep 17 00:00:00 2001 From: weiihann Date: Mon, 17 Feb 2025 18:37:37 +0800 Subject: [PATCH 44/44] Fix encoding bug --- core/trie2/collector.go | 2 +- core/trie2/id.go | 30 ++++------ core/trie2/node_enc.go | 110 +++++++++++++++++++----------------- core/trie2/node_enc_test.go | 14 +---- core/trie2/proof.go | 7 +-- core/trie2/proof_test.go | 12 ++-- core/trie2/trie.go | 26 +++++---- core/trie2/trie_test.go | 23 +++----- 8 files changed, 102 insertions(+), 122 deletions(-) diff --git a/core/trie2/collector.go b/core/trie2/collector.go index 710f55b5bb..d4bcca1aa0 100644 --- a/core/trie2/collector.go +++ b/core/trie2/collector.go @@ -91,7 +91,7 @@ func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2 var wg sync.WaitGroup var mu sync.Mutex - for i := 0; i < 2; i++ { + for i := range 2 { wg.Add(1) go func(idx int) { defer wg.Done() diff --git a/core/trie2/id.go b/core/trie2/id.go index 49e3b09fe6..ed9a8c517d 100644 --- a/core/trie2/id.go +++ b/core/trie2/id.go @@ -15,10 +15,8 @@ const ( // Represents the identifier for uniquely identifying a trie. type ID struct { - TrieType TrieType - Root felt.Felt // The root hash of the trie - Owner felt.Felt // The contract address which the trie belongs to - StorageRoot felt.Felt // The root hash of the storage trie of a contract. + TrieType TrieType + Owner felt.Felt // The contract address which the trie belongs to } // Returns the corresponding DB bucket for the trie @@ -39,31 +37,25 @@ func (id *ID) Bucket() db.Bucket { } // Constructs an identifier for a class trie with the provided class trie root hash -func ClassTrieID(root felt.Felt) *ID { +func ClassTrieID() *ID { return &ID{ - TrieType: ClassTrie, - Root: root, - Owner: felt.Zero, // class trie does not have an owner - StorageRoot: felt.Zero, // only contract storage trie has a storage root + TrieType: ClassTrie, + Owner: felt.Zero, // class trie does not have an owner } } // Constructs an identifier for a contract trie or a contract's storage trie -func ContractTrieID(root, owner, storageRoot felt.Felt) *ID { +func ContractTrieID(owner felt.Felt) *ID { return &ID{ - TrieType: ContractTrie, - Root: root, - Owner: owner, - StorageRoot: storageRoot, + TrieType: ContractTrie, + Owner: owner, } } // A general identifier, typically used for temporary trie -func TrieID(root felt.Felt) *ID { +func TrieID() *ID { return &ID{ - TrieType: Empty, - Root: root, - Owner: felt.Zero, - StorageRoot: felt.Zero, + TrieType: Empty, + Owner: felt.Zero, } } diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go index 3cc0229ecf..83d1e75626 100644 --- a/core/trie2/node_enc.go +++ b/core/trie2/node_enc.go @@ -102,31 +102,15 @@ func nodeToBytes(n node) []byte { } // Decodes the encoded bytes and returns the corresponding node -func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, error) { +func decodeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (node, error) { if len(blob) == 0 { return nil, errors.New("cannot decode empty blob") } - - var ( - n node - err error - isLeaf bool - ) - if pathLen > maxPathLen { - return nil, fmt.Errorf("node path length (%d) is greater than max path length (%d)", pathLen, maxPathLen) + return nil, fmt.Errorf("node path length (%d) > max (%d)", pathLen, maxPathLen) } - isLeaf = pathLen == maxPathLen - - if len(blob) == hashOrValueNodeSize { - var f felt.Felt - f.SetBytes(blob) - if isLeaf { - n = &valueNode{Felt: f} - } else { - n = &hashNode{Felt: f} - } + if n, ok 
:= decodeHashOrValueNode(blob, pathLen, maxPathLen); ok { return n, nil } @@ -135,44 +119,66 @@ func decodeNode(blob []byte, hash felt.Felt, pathLen, maxPathLen uint8) (node, e switch nodeType { case binaryNodeType: - binary := &binaryNode{flags: nodeFlag{hash: &hashNode{Felt: hash}}} // cache the hash - binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen) - if err != nil { - return nil, err - } - binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen) - if err != nil { - return nil, err - } - - n = binary + return decodeBinaryNode(blob, hash, pathLen, maxPathLen) case edgeNodeType: - // Ensure the blob length is within the valid range for an edge node - if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { - return nil, fmt.Errorf("invalid node size: %d", len(blob)) - } + return decodeEdgeNode(blob, hash, pathLen, maxPathLen) + default: + panic(fmt.Sprintf("unknown node type: %d", nodeType)) + } +} - edge := &edgeNode{ - path: &trieutils.BitArray{}, - flags: nodeFlag{hash: &hashNode{Felt: hash}}, // cache the hash - } - edge.child, err = decodeNode(blob[:hashOrValueNodeSize], felt.Felt{}, pathLen, maxPathLen) - if err != nil { - return nil, err - } - if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil { - return nil, err +func decodeHashOrValueNode(blob []byte, pathLen, maxPathLen uint8) (node, bool) { + if len(blob) == hashOrValueNodeSize { + var f felt.Felt + f.SetBytes(blob) + if pathLen == maxPathLen { + return &valueNode{Felt: f}, true } + return &hashNode{Felt: f}, true + } + return nil, false +} - // We do another path length check to see if the node is a leaf - if pathLen+edge.path.Len() == maxPathLen { - edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt} - } +func decodeBinaryNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*binaryNode, error) { + if len(blob) < 2*hashOrValueNodeSize { + return nil, fmt.Errorf("invalid binary node size: %d", len(blob)) + } - n = edge - default: - panic(fmt.Sprintf("unknown decode node type: %d", nodeType)) + binary := &binaryNode{} + if hash != nil { + binary.flags.hash = &hashNode{Felt: *hash} + } + + var err error + if binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen); err != nil { + return nil, err + } + if binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen); err != nil { + return nil, err + } + return binary, nil +} + +func decodeEdgeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*edgeNode, error) { + if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { + return nil, fmt.Errorf("invalid edge node size: %d", len(blob)) + } + + edge := &edgeNode{path: &trieutils.BitArray{}} + if hash != nil { + edge.flags.hash = &hashNode{Felt: *hash} } - return n, nil + var err error + if edge.child, err = decodeNode(blob[:hashOrValueNodeSize], nil, pathLen, maxPathLen); err != nil { + return nil, err + } + if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil { + return nil, err + } + + if pathLen+edge.path.Len() == maxPathLen { + edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt} + } + return edge, nil } diff --git a/core/trie2/node_enc_test.go b/core/trie2/node_enc_test.go index 022e1730f9..a002606e87 100644 --- a/core/trie2/node_enc_test.go +++ b/core/trie2/node_enc_test.go @@ -40,7 +40,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(1, 0, 
1), child: &valueNode{Felt: newFelt(123)}, - flags: nodeFlag{}, }, pathLen: 8, maxPath: 8, @@ -50,7 +49,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(0, 1, 0), child: &hashNode{Felt: newFelt(456)}, - flags: nodeFlag{}, }, pathLen: 3, maxPath: 8, @@ -62,7 +60,6 @@ func TestNodeEncodingDecoding(t *testing.T) { &hashNode{Felt: newFelt(111)}, &hashNode{Felt: newFelt(222)}, }, - flags: nodeFlag{}, }, pathLen: 0, maxPath: 8, @@ -74,7 +71,6 @@ func TestNodeEncodingDecoding(t *testing.T) { &valueNode{Felt: newFelt(555)}, &valueNode{Felt: newFelt(666)}, }, - flags: nodeFlag{}, }, pathLen: 7, maxPath: 8, @@ -96,7 +92,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(), child: &hashNode{Felt: newFelt(1111)}, - flags: nodeFlag{}, }, pathLen: 0, maxPath: 8, @@ -106,7 +101,6 @@ func TestNodeEncodingDecoding(t *testing.T) { node: &edgeNode{ path: newPath(1, 1, 1, 1, 1, 1, 1, 1), child: &valueNode{Felt: newFelt(2222)}, - flags: nodeFlag{}, }, pathLen: 0, maxPath: 8, @@ -120,7 +114,7 @@ func TestNodeEncodingDecoding(t *testing.T) { // Try to decode hash := tt.node.hash(crypto.Pedersen) - decoded, err := decodeNode(encoded, *hash, tt.pathLen, tt.maxPath) + decoded, err := decodeNode(encoded, hash, tt.pathLen, tt.maxPath) if tt.wantErr { require.Error(t, err) @@ -163,16 +157,14 @@ func TestNodeEncodingDecodingBoundary(t *testing.T) { // Test with invalid path lengths t.Run("invalid path lengths", func(t *testing.T) { - hash := felt.Zero blob := make([]byte, hashOrValueNodeSize) - _, err := decodeNode(blob, hash, 255, 8) // pathLen > maxPath + _, err := decodeNode(blob, nil, 255, 8) // pathLen > maxPath require.Error(t, err) }) // Test with empty buffer t.Run("empty buffer", func(t *testing.T) { - hash := felt.Zero - _, err := decodeNode([]byte{}, hash, 0, 8) + _, err := decodeNode([]byte{}, nil, 0, 8) require.Error(t, err) }) } diff --git a/core/trie2/proof.go b/core/trie2/proof.go index 1916cb1339..49210cb210 100644 --- a/core/trie2/proof.go +++ b/core/trie2/proof.go @@ -152,11 +152,6 @@ func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) // - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements. // // The function returns a boolean indicating if there are more elements and an error if the range proof is invalid. -// -// TODO(weiihann): Given a binary leaf and a left-sibling first key, if the right sibling is removed, the proof would still be valid. -// Conversely, given a binary leaf and a right-sibling last key, if the left sibling is removed, the proof would still be valid. -// Range proof should not be valid for both of these cases, but currently is, which is an attack vector. -// The problem probably lies in how we do root hash calculation. 
func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo // Ensure the number of keys and values are the same if len(keys) != len(values) { @@ -164,7 +159,7 @@ func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, pro } // Ensure all keys are monotonically increasing and values contain no deletions - for i := 0; i < len(keys); i++ { + for i := range keys { if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 { return false, errors.New("keys are not monotonic increasing") } diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go index ed462e7992..69390425d7 100644 --- a/core/trie2/proof_test.go +++ b/core/trie2/proof_test.go @@ -372,7 +372,7 @@ func TestSingleSideRangeProof(t *testing.T) { keys := make([]*felt.Felt, i+1) values := make([]*felt.Felt, i+1) - for j := 0; j < i+1; j++ { + for j := range i + 1 { keys[j] = records[j].key values[j] = records[j].value } @@ -487,7 +487,7 @@ func TestBadRangeProof(t *testing.T) { tr, records := randomTrie(t, 5000) root := tr.Hash() - for i := 0; i < 500; i++ { + for range 500 { start := rand.Intn(len(records)) end := rand.Intn(len(records)-start) + start + 1 @@ -539,7 +539,7 @@ func TestBadRangeProof(t *testing.T) { func BenchmarkProve(b *testing.B) { tr, records := randomTrie(b, 1000) b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { proof := NewProofNodeSet() key := records[i%len(records)].key if err := tr.Prove(key, proof); err != nil { @@ -552,7 +552,7 @@ func BenchmarkVerifyProof(b *testing.B) { tr, records := randomTrie(b, 1000) root := tr.Hash() - var proofs []*ProofNodeSet + proofs := make([]*ProofNodeSet, 0, len(records)) for _, record := range records { proof := NewProofNodeSet() if err := tr.Prove(record.key, proof); err != nil { @@ -562,7 +562,7 @@ func BenchmarkVerifyProof(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { index := i % len(records) if _, err := VerifyProof(&root, records[index].key, proofs[index], crypto.Pedersen); err != nil { b.Fatal(err) @@ -589,7 +589,7 @@ func BenchmarkVerifyRangeProof(b *testing.B) { } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _, err := VerifyRangeProof(&root, keys[0], keys, values, proof) require.NoError(b, err) } diff --git a/core/trie2/trie.go b/core/trie2/trie.go index d3830e9bc3..d00b16cd73 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -2,6 +2,7 @@ package trie2 import ( "bytes" + "errors" "fmt" "github.com/NethermindEth/juno/core/crypto" @@ -14,8 +15,6 @@ import ( const contractClassTrieHeight = 251 -var emptyRoot = felt.Felt{} - type Path = trieutils.BitArray type Trie struct { @@ -57,13 +56,13 @@ func New(id *ID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, nodeTracer: newTracer(), } - if id.Root != emptyRoot { - root, err := tr.resolveNode(&hashNode{Felt: id.Root}, Path{}) - if err != nil { - return nil, err - } + root, err := tr.resolveNode(nil, Path{}) + if err == nil { tr.root = root + } else if !errors.Is(err, db.ErrKeyNotFound) { + return nil, err } + return tr, nil } @@ -455,7 +454,7 @@ func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { } // Resolves the node at the given path from the database -func (t *Trie) resolveNode(hash *hashNode, path Path) (node, error) { +func (t *Trie) resolveNode(hn *hashNode, path Path) (node, error) { buf := bufferPool.Get().(*bytes.Buffer) buf.Reset() defer func() { @@ -469,7 +468,12 @@ func (t *Trie) resolveNode(hash 
*hashNode, path Path) (node, error) { } blob := buf.Bytes() - return decodeNode(blob, hash.Felt, path.Len(), t.height) + + var hash *felt.Felt + if hn != nil { + hash = &hn.Felt + } + return decodeNode(blob, hash, path.Len(), t.height) } // Calculate the hash of the root node @@ -499,9 +503,9 @@ func (t *Trie) String() string { } func NewEmptyPedersen() (*Trie, error) { - return New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) + return New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) } func NewEmptyPoseidon() (*Trie, error) { - return New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Poseidon, db.NewMemTransaction()) + return New(TrieID(), contractClassTrieHeight, crypto.Poseidon, db.NewMemTransaction()) } diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index 7effe704a2..4b19a323ef 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -182,20 +182,11 @@ func TestHash(t *testing.T) { }) } -func TestMissingRoot(t *testing.T) { - var root felt.Felt - root.SetUint64(1) - - tr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) - require.Nil(t, tr) - require.Error(t, err) -} - func TestCommit(t *testing.T) { verifyCommit := func(t *testing.T, records []*keyValue) { t.Helper() db := db.NewMemTransaction() - tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, db) + tr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db) require.NoError(t, err) for _, record := range records { @@ -203,10 +194,10 @@ func TestCommit(t *testing.T) { require.NoError(t, err) } - root, err := tr.Commit() + _, err = tr.Commit() require.NoError(t, err) - tr2, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, db) + tr2, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db) require.NoError(t, err) for _, record := range records { @@ -361,7 +352,7 @@ func runRandTestBool(rt randTest) bool { //nolint:gocyclo func runRandTest(rt randTest) error { txn := db.NewMemTransaction() - tr, err := New(TrieID(felt.Zero), contractClassTrieHeight, crypto.Pedersen, txn) + tr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, txn) if err != nil { return err } @@ -408,11 +399,11 @@ func runRandTest(rt randTest) error { case opHash: tr.Hash() case opCommit: - root, err := tr.Commit() + _, err := tr.Commit() if err != nil { rt[i].err = fmt.Errorf("commit failed: %w", err) } - newtr, err := New(TrieID(root), contractClassTrieHeight, crypto.Pedersen, txn) + newtr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, txn) if err != nil { rt[i].err = fmt.Errorf("new trie failed: %w", err) } @@ -451,7 +442,7 @@ func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { tr, _ := NewEmptyPedersen() records := make([]*keyValue, n) - for i := 0; i < n; i++ { + for i := range n { key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) records[i] = &keyValue{key: key, value: key} err := tr.Update(key, key)