From 0bd5aa9992ded53c9a30d2afadf9dbb0b51889eb Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 2 Jan 2025 11:14:51 +0800 Subject: [PATCH] Add hasher --- core/trie2/hasher.go | 85 +++++++++++++++++++++++++++++ core/trie2/trie.go | 18 +++++- core/trie2/trie_test.go | 118 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 216 insertions(+), 5 deletions(-) create mode 100644 core/trie2/hasher.go diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go new file mode 100644 index 0000000000..7152cd12ec --- /dev/null +++ b/core/trie2/hasher.go @@ -0,0 +1,85 @@ +package trie2 + +import ( + "fmt" + "sync" + + "github.com/NethermindEth/juno/core/crypto" +) + +// hasher handles node hashing for the trie. It supports both sequential and parallel +// hashing modes. +type hasher struct { + hashFn crypto.HashFn // The hash function to use + parallel bool // Whether to hash binary node children in parallel +} + +func newHasher(hash crypto.HashFn, parallel bool) hasher { + return hasher{ + hashFn: hash, + parallel: parallel, + } +} + +// hash computes the hash of a node and returns both the hash node and a cached +// version of the original node. If the node already has a cached hash, returns +// that instead of recomputing. +func (h *hasher) hash(n node) (node, node) { + if hash, _ := n.cache(); hash != nil { + return hash, n + } + + switch n := n.(type) { + case *edgeNode: + collapsed, cached := h.hashEdgeChild(n) + hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case *binaryNode: + collapsed, cached := h.hashBinaryChildren(n) + hn := &hashNode{Felt: collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case valueNode, hashNode: + return n, n + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (h *hasher) hashEdgeChild(n *edgeNode) (collapsed, cached *edgeNode) { + collapsed, cached = n.copy(), n.copy() + + switch n.child.(type) { + case *edgeNode, *binaryNode: + collapsed.child, cached.child = h.hash(n.child) + } + + return collapsed, cached +} + +func (h *hasher) hashBinaryChildren(n *binaryNode) (collapsed, cached *binaryNode) { + collapsed, cached = n.copy(), n.copy() + + if h.parallel { // TODO(weiihann): double check this parallel strategy + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + }() + + go func() { + defer wg.Done() + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + }() + + wg.Wait() + } else { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } + + return collapsed, cached +} diff --git a/core/trie2/trie.go b/core/trie2/trie.go index 2e97770488..cff79c3a3a 100644 --- a/core/trie2/trie.go +++ b/core/trie2/trie.go @@ -3,6 +3,7 @@ package trie2 import ( "fmt" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" ) @@ -11,11 +12,12 @@ type Trie struct { root node reader interface{} // TODO(weiihann): implement reader // committed bool + hashFn crypto.HashFn } // TODO(weiihann): implement this -func NewTrie(height uint8) *Trie { - return &Trie{height: height} +func NewTrie(height uint8, hashFn crypto.HashFn) *Trie { + return &Trie{height: height, hashFn: hashFn} } // Modifies or inserts a key-value pair in the trie. @@ -55,6 +57,13 @@ func (t *Trie) Delete(key *felt.Felt) error { return nil } +// Returns the root hash of the trie. Calling this method will also cache the hash of each node in the trie. +func (t *Trie) Hash() *felt.Felt { + hash, cached := t.hashRoot() + t.root = cached + return hash.(*hashNode).Felt +} + func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) { switch n := n.(type) { case *edgeNode: @@ -244,6 +253,11 @@ func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) { } } +func (t *Trie) hashRoot() (node, node) { + h := newHasher(t.hashFn, false) // TODO(weiihann): handle parallel hashing + return h.hash(t.root) +} + // Converts a Felt value into a BitArray representation suitable for // use as a trie key with the specified height. func (t *Trie) FeltToKey(f *felt.Felt) BitArray { diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go index f3001b694b..6361eba4fa 100644 --- a/core/trie2/trie_test.go +++ b/core/trie2/trie_test.go @@ -4,7 +4,9 @@ import ( "math/rand" "testing" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" ) @@ -60,13 +62,123 @@ func TestDeleteRandom(t *testing.T) { } } +// The expected hashes are taken from Pathfinder's tests +func TestHash(t *testing.T) { + t.Run("one leaf", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + hash := tr.Hash() + + expected := "0x2ab889bd35e684623df9b4ea4a4a1f6d9e0ef39b67c1293b8a89dd17e351330" + require.Equal(t, expected, hash.String(), "expected %s, got %s", expected, hash.String()) + }) + + t.Run("two leaves", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + err := tr.Update(new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + err = tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(3)) + require.NoError(t, err) + root := tr.Hash() + + expected := "0x79acdb7a3d78052114e21458e8c4aecb9d781ce79308193c61a2f3f84439f66" + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("three leaves", func(t *testing.T) { + tr := NewTrie(251, crypto.Pedersen) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(16), + new(felt.Felt).SetUint64(17), + new(felt.Felt).SetUint64(19), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(10), + new(felt.Felt).SetUint64(11), + new(felt.Felt).SetUint64(12), + } + + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x7e2184e9e1a651fd556b42b4ff10e44a71b1709f641e0203dc8bd2b528e5e81" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("double binary", func(t *testing.T) { + // (249,0,x3) + // | + // (0, 0, x3) + // / \ + // (0,0,x1) (1, 1, 5) + // / \ | + // (2) (3) (5) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + new(felt.Felt).SetUint64(1), + new(felt.Felt).SetUint64(3), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(2), + new(felt.Felt).SetUint64(3), + new(felt.Felt).SetUint64(5), + } + + tr := NewTrie(251, crypto.Pedersen) + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x6a316f09913454294c6b6751dea8449bc2e235fdc04b2ab0e1ac7fea25cc34f" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("binary root", func(t *testing.T) { + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), + } + + vals := []*felt.Felt{ + utils.HexToFelt(t, "0xcc"), + utils.HexToFelt(t, "0xdd"), + } + + tr := NewTrie(251, crypto.Pedersen) + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x542ced3b6aeef48339129a03e051693eff6a566d3a0a94035fa16ab610dc9e2" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) +} + type keyValue struct { key *felt.Felt value *felt.Felt } func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { - tr := NewTrie(251) + tr := NewTrie(251, crypto.Pedersen) records := make([]*keyValue, numKeys) for i := 1; i < numKeys+1; i++ { @@ -82,7 +194,7 @@ func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { rrand := rand.New(rand.NewSource(3)) - tr := NewTrie(251) + tr := NewTrie(251, crypto.Pedersen) records := make([]*keyValue, n) for i := 0; i < n; i++ { @@ -111,7 +223,7 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie { t.Fatal("records must have at least one element") } - tempTrie := NewTrie(251) + tempTrie := NewTrie(251, crypto.Pedersen) for _, record := range records { err := tempTrie.Update(record.key, record.value)