diff --git a/core/trie2/collector.go b/core/trie2/collector.go new file mode 100644 index 0000000000..d4bcca1aa0 --- /dev/null +++ b/core/trie2/collector.go @@ -0,0 +1,120 @@ +package trie2 + +import ( + "fmt" + "sync" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trienode" +) + +// Used as a tool to collect all dirty nodes in a NodeSet +type collector struct { + nodes *trienode.NodeSet +} + +func newCollector(nodes *trienode.NodeSet) *collector { + return &collector{nodes: nodes} +} + +// Collects the nodes in the node set and collapses a node into a hash node +func (c *collector) Collect(n node, parallel bool) *hashNode { + return c.collect(new(Path), n, parallel).(*hashNode) +} + +func (c *collector) collect(path *Path, n node, parallel bool) node { + // This path has not been modified, just return the cache + hash, dirty := n.cache() + if hash != nil && !dirty { + return hash + } + + // Collect children and then parent + switch cn := n.(type) { + case *edgeNode: + collapsed := cn.copy() + + // If the child is a binary node, recurse into it. + // Otherwise, it can only be a hashNode or valueNode. + // Combination of edge (parent) + edge (child) is not possible. + collapsed.child = c.collect(new(Path).Append(path, cn.path), cn.child, parallel) + return c.store(path, collapsed) + case *binaryNode: + collapsed := cn.copy() + collapsed.children = c.collectChildren(path, cn, parallel) + return c.store(path, collapsed) + case *hashNode: + return cn + case *valueNode: // each leaf node is stored as a single entry in the node set + return c.store(path, cn) + default: + panic(fmt.Sprintf("unknown node type: %T", cn)) + } +} + +// Collects the children of a binary node, may apply parallel processing if configured +func (c *collector) collectChildren(path *Path, n *binaryNode, parallel bool) [2]node { + children := [2]node{} + + // Helper function to process a single child + processChild := func(i int) { + child := n.children[i] + // Return early if it's already a hash node + if hn, ok := child.(*hashNode); ok { + children[i] = hn + return + } + + // Create child path + childPath := new(Path).AppendBit(path, uint8(i)) + + if !parallel { + children[i] = c.collect(childPath, child, parallel) + return + } + + // Parallel processing + childSet := trienode.NewNodeSet(c.nodes.Owner) + childCollector := newCollector(childSet) + children[i] = childCollector.collect(childPath, child, parallel) + c.nodes.MergeSet(childSet) //nolint:errcheck // guaranteed to succeed because same owner + } + + if !parallel { + // Sequential processing + processChild(0) + processChild(1) + return children + } + + // Parallel processing + var wg sync.WaitGroup + var mu sync.Mutex + + for i := range 2 { + wg.Add(1) + go func(idx int) { + defer wg.Done() + mu.Lock() + processChild(idx) + mu.Unlock() + }(i) + } + wg.Wait() + + return children +} + +// Stores the node in the node set and returns the hash node +func (c *collector) store(path *Path, n node) node { + hash, _ := n.cache() + + blob := nodeToBytes(n) + if hash == nil { // this is a value node + c.nodes.Add(*path, trienode.NewNode(felt.Felt{}, blob)) + return n + } + + c.nodes.Add(*path, trienode.NewNode(hash.Felt, blob)) + return hash +} diff --git a/core/trie2/errors.go b/core/trie2/errors.go new file mode 100644 index 0000000000..aace1a44d9 --- /dev/null +++ b/core/trie2/errors.go @@ -0,0 +1,8 @@ +package trie2 + +import "errors" + +var ( + ErrCommitted = errors.New("trie is committed") + ErrEmptyRange = errors.New("empty range") +) diff --git a/core/trie2/hasher.go b/core/trie2/hasher.go new file mode 100644 index 0000000000..6917896b0b --- /dev/null +++ b/core/trie2/hasher.go @@ -0,0 +1,117 @@ +package trie2 + +import ( + "fmt" + "sync" + + "github.com/NethermindEth/juno/core/crypto" +) + +// A tool for shashing nodes in the trie. It supports both sequential and parallel +// hashing modes. +type hasher struct { + hashFn crypto.HashFn // The hash function to use + parallel bool // Whether to hash binary node children in parallel +} + +func newHasher(hash crypto.HashFn, parallel bool) hasher { + return hasher{ + hashFn: hash, + parallel: parallel, + } +} + +// Computes the hash of a node and returns both the hash node and a cached +// version of the original node. If the node already has a cached hash, returns +// that instead of recomputing. +func (h *hasher) hash(n node) (node, node) { + if hash, _ := n.cache(); hash != nil { + return hash, n + } + + switch n := n.(type) { + case *edgeNode: + collapsed, cached := h.hashEdgeChild(n) + hn := &hashNode{Felt: *collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case *binaryNode: + collapsed, cached := h.hashBinaryChildren(n) + hn := &hashNode{Felt: *collapsed.hash(h.hashFn)} + cached.flags.hash = hn + return hn, cached + case *valueNode, *hashNode: + return n, n + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +func (h *hasher) hashEdgeChild(n *edgeNode) (collapsed, cached *edgeNode) { + collapsed, cached = n.copy(), n.copy() + + switch n.child.(type) { + case *edgeNode, *binaryNode: + collapsed.child, cached.child = h.hash(n.child) + } + + return collapsed, cached +} + +func (h *hasher) hashBinaryChildren(n *binaryNode) (collapsed, cached *binaryNode) { + collapsed, cached = n.copy(), n.copy() + + if h.parallel { // TODO(weiihann): double check this parallel strategy + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + if n.children[0] != nil { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + } else { + collapsed.children[0], cached.children[0] = nilValueNode, nilValueNode + } + }() + + go func() { + defer wg.Done() + if n.children[1] != nil { + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } else { + collapsed.children[1], cached.children[1] = nilValueNode, nilValueNode + } + }() + + wg.Wait() + } else { + if n.children[0] != nil { + collapsed.children[0], cached.children[0] = h.hash(n.children[0]) + } else { + collapsed.children[0], cached.children[0] = nilValueNode, nilValueNode + } + + if n.children[1] != nil { + collapsed.children[1], cached.children[1] = h.hash(n.children[1]) + } else { + collapsed.children[1], cached.children[1] = nilValueNode, nilValueNode + } + } + + return collapsed, cached +} + +// Construct trie proofs and returns the collapsed node (i.e. nodes with hash children) +// and the hashed node. +func (h *hasher) proofHash(original node) (collapsed, hashed node) { + switch n := original.(type) { + case *edgeNode: + en, _ := h.hashEdgeChild(n) + return en, &hashNode{Felt: *en.hash(h.hashFn)} + case *binaryNode: + bn, _ := h.hashBinaryChildren(n) + return bn, &hashNode{Felt: *bn.hash(h.hashFn)} + default: + return n, n + } +} diff --git a/core/trie2/id.go b/core/trie2/id.go new file mode 100644 index 0000000000..ed9a8c517d --- /dev/null +++ b/core/trie2/id.go @@ -0,0 +1,61 @@ +package trie2 + +import ( + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" +) + +type TrieType uint8 + +const ( + Empty TrieType = iota + ClassTrie + ContractTrie +) + +// Represents the identifier for uniquely identifying a trie. +type ID struct { + TrieType TrieType + Owner felt.Felt // The contract address which the trie belongs to +} + +// Returns the corresponding DB bucket for the trie +func (id *ID) Bucket() db.Bucket { + switch id.TrieType { + case ClassTrie: + return db.ClassTrie + case ContractTrie: + if id.Owner == (felt.Felt{}) { + return db.ContractTrieContract + } + return db.ContractTrieStorage + case Empty: + return db.Bucket(0) + default: + panic("invalid trie type") + } +} + +// Constructs an identifier for a class trie with the provided class trie root hash +func ClassTrieID() *ID { + return &ID{ + TrieType: ClassTrie, + Owner: felt.Zero, // class trie does not have an owner + } +} + +// Constructs an identifier for a contract trie or a contract's storage trie +func ContractTrieID(owner felt.Felt) *ID { + return &ID{ + TrieType: ContractTrie, + Owner: owner, + } +} + +// A general identifier, typically used for temporary trie +func TrieID() *ID { + return &ID{ + TrieType: Empty, + Owner: felt.Zero, + } +} diff --git a/core/trie2/node.go b/core/trie2/node.go new file mode 100644 index 0000000000..f5fc456fe9 --- /dev/null +++ b/core/trie2/node.go @@ -0,0 +1,123 @@ +package trie2 + +import ( + "bytes" + "fmt" + "strings" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" +) + +var ( + _ node = (*binaryNode)(nil) + _ node = (*edgeNode)(nil) + _ node = (*hashNode)(nil) + _ node = (*valueNode)(nil) +) + +type node interface { + hash(crypto.HashFn) *felt.Felt // TODO(weiihann): return felt value instead of pointers + cache() (*hashNode, bool) + write(*bytes.Buffer) error + String() string +} + +type ( + binaryNode struct { + children [2]node // 0 = left, 1 = right + flags nodeFlag + } + edgeNode struct { + child node + path *Path + flags nodeFlag + } + hashNode struct{ felt.Felt } + valueNode struct{ felt.Felt } +) + +// nilValueNode is used when collapsing internal trie nodes for hashing, since unset children need to be hashed correctly +var nilValueNode = &valueNode{felt.Felt{}} + +type nodeFlag struct { + hash *hashNode + dirty bool +} + +func newFlag() nodeFlag { return nodeFlag{dirty: true} } + +func (n *binaryNode) hash(hf crypto.HashFn) *felt.Felt { + return hf(n.children[0].hash(hf), n.children[1].hash(hf)) +} + +func (n *edgeNode) hash(hf crypto.HashFn) *felt.Felt { + var length [32]byte + length[31] = n.path.Len() + pathFelt := n.path.Felt() + lengthFelt := new(felt.Felt).SetBytes(length[:]) + return new(felt.Felt).Add(hf(n.child.hash(hf), &pathFelt), lengthFelt) +} + +func (n *hashNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } +func (n *valueNode) hash(crypto.HashFn) *felt.Felt { return &n.Felt } + +func (n *binaryNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *edgeNode) cache() (*hashNode, bool) { return n.flags.hash, n.flags.dirty } +func (n *hashNode) cache() (*hashNode, bool) { return nil, true } +func (n *valueNode) cache() (*hashNode, bool) { return nil, true } + +func (n *binaryNode) String() string { + var left, right string + if n.children[0] != nil { + left = n.children[0].String() + } + if n.children[1] != nil { + right = n.children[1].String() + } + return fmt.Sprintf("Binary[\n left: %s\n right: %s\n]", + indent(left), + indent(right)) +} + +func (n *edgeNode) String() string { + var child string + if n.child != nil { + child = n.child.String() + } + return fmt.Sprintf("Edge{\n path: %s\n child: %s\n}", + n.path.String(), + indent(child)) +} + +func (n hashNode) String() string { + return fmt.Sprintf("Hash(%s)", n.Felt.String()) +} + +func (n valueNode) String() string { + return fmt.Sprintf("Value(%s)", n.Felt.String()) +} + +// TODO(weiihann): check if we want to return a pointer or a value +func (n *binaryNode) copy() *binaryNode { cpy := *n; return &cpy } +func (n *edgeNode) copy() *edgeNode { cpy := *n; return &cpy } + +func (n *edgeNode) pathMatches(key *Path) bool { + return n.path.EqualMSBs(key) +} + +// Returns the common bits between the current node and the given key, starting from the most significant bit +func (n *edgeNode) commonPath(key *Path) Path { + var commonPath Path + commonPath.CommonMSBs(n.path, key) + return commonPath +} + +// Helper function to indent each line of a string +func indent(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + lines[i] = " " + line + } + return strings.Join(lines, "\n") +} diff --git a/core/trie2/node_enc.go b/core/trie2/node_enc.go new file mode 100644 index 0000000000..83d1e75626 --- /dev/null +++ b/core/trie2/node_enc.go @@ -0,0 +1,184 @@ +package trie2 + +import ( + "bytes" + "errors" + "fmt" + "sync" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" +) + +const ( + binaryNodeSize = 2 * hashOrValueNodeSize // LeftHash + RightHash + edgeNodeMaxSize = trieutils.MaxBitArraySize + hashOrValueNodeSize // Path + Child Hash + hashOrValueNodeSize = felt.Bytes +) + +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +// The initial idea was to differentiate between binary and edge nodes by the size of the buffer. +// However, there may be a case where the size of the buffer is the same for both binary and edge nodes. +// In this case, we need to prepend the buffer with a byte to indicate the type of the node. +// Value and hash nodes do not need this because their size is fixed. +const ( + binaryNodeType byte = iota + 1 + edgeNodeType +) + +// Enc(binary) = binaryNodeType + HashNode(left) + HashNode(right) +func (n *binaryNode) write(buf *bytes.Buffer) error { + if err := buf.WriteByte(binaryNodeType); err != nil { + return err + } + + if err := n.children[0].write(buf); err != nil { + return err + } + + if err := n.children[1].write(buf); err != nil { + return err + } + + return nil +} + +// Enc(edge) = edgeNodeType + HashNode(child) + Path +func (n *edgeNode) write(buf *bytes.Buffer) error { + if err := buf.WriteByte(edgeNodeType); err != nil { + return err + } + + if err := n.child.write(buf); err != nil { + return err + } + + if _, err := n.path.Write(buf); err != nil { + return err + } + + return nil +} + +// Enc(hash) = Felt +func (n *hashNode) write(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +// Enc(value) = Felt +func (n *valueNode) write(buf *bytes.Buffer) error { + if _, err := buf.Write(n.Felt.Marshal()); err != nil { + return err + } + + return nil +} + +// Returns the encoded bytes of a node +func nodeToBytes(n node) []byte { + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + if err := n.write(buf); err != nil { + panic(err) + } + + res := make([]byte, buf.Len()) + copy(res, buf.Bytes()) + return res +} + +// Decodes the encoded bytes and returns the corresponding node +func decodeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (node, error) { + if len(blob) == 0 { + return nil, errors.New("cannot decode empty blob") + } + if pathLen > maxPathLen { + return nil, fmt.Errorf("node path length (%d) > max (%d)", pathLen, maxPathLen) + } + + if n, ok := decodeHashOrValueNode(blob, pathLen, maxPathLen); ok { + return n, nil + } + + nodeType := blob[0] + blob = blob[1:] + + switch nodeType { + case binaryNodeType: + return decodeBinaryNode(blob, hash, pathLen, maxPathLen) + case edgeNodeType: + return decodeEdgeNode(blob, hash, pathLen, maxPathLen) + default: + panic(fmt.Sprintf("unknown node type: %d", nodeType)) + } +} + +func decodeHashOrValueNode(blob []byte, pathLen, maxPathLen uint8) (node, bool) { + if len(blob) == hashOrValueNodeSize { + var f felt.Felt + f.SetBytes(blob) + if pathLen == maxPathLen { + return &valueNode{Felt: f}, true + } + return &hashNode{Felt: f}, true + } + return nil, false +} + +func decodeBinaryNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*binaryNode, error) { + if len(blob) < 2*hashOrValueNodeSize { + return nil, fmt.Errorf("invalid binary node size: %d", len(blob)) + } + + binary := &binaryNode{} + if hash != nil { + binary.flags.hash = &hashNode{Felt: *hash} + } + + var err error + if binary.children[0], err = decodeNode(blob[:hashOrValueNodeSize], hash, pathLen+1, maxPathLen); err != nil { + return nil, err + } + if binary.children[1], err = decodeNode(blob[hashOrValueNodeSize:], hash, pathLen+1, maxPathLen); err != nil { + return nil, err + } + return binary, nil +} + +func decodeEdgeNode(blob []byte, hash *felt.Felt, pathLen, maxPathLen uint8) (*edgeNode, error) { + if len(blob) > edgeNodeMaxSize || len(blob) < hashOrValueNodeSize { + return nil, fmt.Errorf("invalid edge node size: %d", len(blob)) + } + + edge := &edgeNode{path: &trieutils.BitArray{}} + if hash != nil { + edge.flags.hash = &hashNode{Felt: *hash} + } + + var err error + if edge.child, err = decodeNode(blob[:hashOrValueNodeSize], nil, pathLen, maxPathLen); err != nil { + return nil, err + } + if err := edge.path.UnmarshalBinary(blob[hashOrValueNodeSize:]); err != nil { + return nil, err + } + + if pathLen+edge.path.Len() == maxPathLen { + edge.child = &valueNode{Felt: edge.child.(*hashNode).Felt} + } + return edge, nil +} diff --git a/core/trie2/node_enc_test.go b/core/trie2/node_enc_test.go new file mode 100644 index 0000000000..a002606e87 --- /dev/null +++ b/core/trie2/node_enc_test.go @@ -0,0 +1,170 @@ +package trie2 + +import ( + "bytes" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNodeEncodingDecoding(t *testing.T) { + newFelt := func(v uint64) felt.Felt { + var f felt.Felt + f.SetUint64(v) + return f + } + + // Helper to create a path with specific bits + newPath := func(bits ...uint8) *trieutils.BitArray { + path := &trieutils.BitArray{} + for _, bit := range bits { + path.AppendBit(path, bit) + } + return path + } + + tests := []struct { + name string + node node + pathLen uint8 + maxPath uint8 + wantErr bool + errMsg string + }{ + { + name: "edge node with value child", + node: &edgeNode{ + path: newPath(1, 0, 1), + child: &valueNode{Felt: newFelt(123)}, + }, + pathLen: 8, + maxPath: 8, + }, + { + name: "edge node with hash child", + node: &edgeNode{ + path: newPath(0, 1, 0), + child: &hashNode{Felt: newFelt(456)}, + }, + pathLen: 3, + maxPath: 8, + }, + { + name: "binary node with two hash children", + node: &binaryNode{ + children: [2]node{ + &hashNode{Felt: newFelt(111)}, + &hashNode{Felt: newFelt(222)}, + }, + }, + pathLen: 0, + maxPath: 8, + }, + { + name: "binary node with two leaf children", + node: &binaryNode{ + children: [2]node{ + &valueNode{Felt: newFelt(555)}, + &valueNode{Felt: newFelt(666)}, + }, + }, + pathLen: 7, + maxPath: 8, + }, + { + name: "value node at max path length", + node: &valueNode{Felt: newFelt(999)}, + pathLen: 8, + maxPath: 8, + }, + { + name: "hash node below max path length", + node: &hashNode{Felt: newFelt(1000)}, + pathLen: 7, + maxPath: 8, + }, + { + name: "edge node with empty path", + node: &edgeNode{ + path: newPath(), + child: &hashNode{Felt: newFelt(1111)}, + }, + pathLen: 0, + maxPath: 8, + }, + { + name: "edge node with max length path", + node: &edgeNode{ + path: newPath(1, 1, 1, 1, 1, 1, 1, 1), + child: &valueNode{Felt: newFelt(2222)}, + }, + pathLen: 0, + maxPath: 8, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Encode the node + encoded := nodeToBytes(tt.node) + + // Try to decode + hash := tt.node.hash(crypto.Pedersen) + decoded, err := decodeNode(encoded, hash, tt.pathLen, tt.maxPath) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + return + } + + require.NoError(t, err) + require.NotNil(t, decoded) + + // Re-encode the decoded node and compare with original encoding + reEncoded := nodeToBytes(decoded) + assert.True(t, bytes.Equal(encoded, reEncoded), "re-encoded node doesn't match original encoding") + + // Test specific node type assertions and properties + switch n := decoded.(type) { + case *edgeNode: + original := tt.node.(*edgeNode) + assert.Equal(t, original.path.String(), n.path.String(), "edge node paths don't match") + case *binaryNode: + // Verify both children are present + assert.NotNil(t, n.children[0], "left child is nil") + assert.NotNil(t, n.children[1], "right child is nil") + case *valueNode: + original := tt.node.(*valueNode) + assert.True(t, original.Felt.Equal(&n.Felt), "value node felts don't match") + case *hashNode: + original := tt.node.(*hashNode) + assert.True(t, original.Felt.Equal(&n.Felt), "hash node felts don't match") + } + }) + } +} + +func TestNodeEncodingDecodingBoundary(t *testing.T) { + // Test empty/nil nodes + t.Run("nil node encoding", func(t *testing.T) { + assert.Panics(t, func() { nodeToBytes(nil) }) + }) + + // Test with invalid path lengths + t.Run("invalid path lengths", func(t *testing.T) { + blob := make([]byte, hashOrValueNodeSize) + _, err := decodeNode(blob, nil, 255, 8) // pathLen > maxPath + require.Error(t, err) + }) + + // Test with empty buffer + t.Run("empty buffer", func(t *testing.T) { + _, err := decodeNode([]byte{}, nil, 0, 8) + require.Error(t, err) + }) +} diff --git a/core/trie2/proof.go b/core/trie2/proof.go new file mode 100644 index 0000000000..49210cb210 --- /dev/null +++ b/core/trie2/proof.go @@ -0,0 +1,596 @@ +package trie2 + +import ( + "errors" + "fmt" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" +) + +type ProofNodeSet = utils.OrderedSet[felt.Felt, node] + +func NewProofNodeSet() *ProofNodeSet { + return utils.NewOrderedSet[felt.Felt, node]() +} + +// Prove generates a Merkle proof for a given key in the trie. +// The result contains the proof nodes on the path from the root to the leaf. +// The value is included in the proof if the key is present in the trie. +// If the key is not present, the proof will contain the nodes on the path to the closest ancestor. +func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { + if t.committed { + return ErrCommitted + } + + path := t.FeltToPath(key) + k := &path + + var ( + nodes []node + prefix = new(Path) + rn = t.root + ) + + for k.Len() > 0 && rn != nil { + switch n := rn.(type) { + case *edgeNode: + if !n.pathMatches(k) { + rn = nil // Trie doesn't contain the key + } else { + rn = n.child + prefix.Append(prefix, n.path) + k.LSBs(k, n.path.Len()) + } + nodes = append(nodes, n) + case *binaryNode: + bit := k.MSB() + rn = n.children[bit] + prefix.AppendBit(prefix, bit) + k.LSBs(k, 1) + nodes = append(nodes, n) + case *hashNode: + resolved, err := t.resolveNode(n, *prefix) + if err != nil { + return err + } + rn = resolved + default: + panic(fmt.Sprintf("key: %s, unknown node type: %T", key.String(), n)) + } + } + + // TODO: ideally Hash() should be called before Prove() so that the hashes are cached + // There should be a better way to do this + hasher := newHasher(t.hashFn, false) + for i, n := range nodes { + var hn node + n, hn = hasher.proofHash(n) + if hash, ok := hn.(*hashNode); ok || i == 0 { + if !ok { + hash = &hashNode{Felt: *n.hash(hasher.hashFn)} + } + proof.Put(hash.Felt, n) + } + } + + return nil +} + +// GetRangeProof generates a range proof for the given range of keys. +// The proof contains the proof nodes on the path from the root to the closest ancestor of the left and right keys. +func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSet) error { + err := t.Prove(leftKey, proofSet) + if err != nil { + return err + } + + // If they are the same key, don't need to generate the proof again + if leftKey.Equal(rightKey) { + return nil + } + + err = t.Prove(rightKey, proofSet) + if err != nil { + return err + } + + return nil +} + +// VerifyProof verifies that a proof path is valid for a given key in a binary trie. +// It walks through the proof nodes, verifying each step matches the expected path to reach the key. +// +// The proof is considered invalid if: +// - Any proof node is missing from the node set +// - Any node's computed hash doesn't match its expected hash +// - The path bits don't match the key bits +// - The proof ends before processing all key bits +func VerifyProof(root, key *felt.Felt, proof *ProofNodeSet, hash crypto.HashFn) (felt.Felt, error) { + keyBits := new(Path).SetFelt(contractClassTrieHeight, key) + expected := *root + h := newHasher(hash, false) + + for { + node, ok := proof.Get(expected) + if !ok { + return felt.Zero, fmt.Errorf("proof node not found, expected hash: %s", expected.String()) + } + + nHash, _ := h.hash(node) + + // Verify the hash matches + if !nHash.(*hashNode).Felt.Equal(&expected) { + return felt.Zero, fmt.Errorf("proof node hash mismatch, expected hash: %s, got hash: %s", expected.String(), nHash.String()) + } + + child := get(node, keyBits, false) + switch cld := child.(type) { + case nil: + return felt.Zero, nil + case *hashNode: + expected = cld.Felt + case *valueNode: + return cld.Felt, nil + case *edgeNode, *binaryNode: + if hash, _ := cld.cache(); hash != nil { + expected = hash.Felt + } + } + } +} + +// VerifyRangeProof checks the validity of given key-value pairs and range proof against a provided root hash. +// The key-value pairs should be consecutive (no gaps) and monotonically increasing. +// The range proof contains two edge proofs: one for the first key and another for the last key. +// Both edge proofs can be for existent or non-existent keys. +// This function handles the following special cases: +// +// - All elements proof: The proof can be nil if the range includes all leaves in the trie. +// - Single element proof: Both left and right edge proofs are identical, and the range contains only one element. +// - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements. +// +// The function returns a boolean indicating if there are more elements and an error if the range proof is invalid. +func VerifyRangeProof(rootHash, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo + // Ensure the number of keys and values are the same + if len(keys) != len(values) { + return false, fmt.Errorf("inconsistent length of proof data, keys: %d, values: %d", len(keys), len(values)) + } + + // Ensure all keys are monotonically increasing and values contain no deletions + for i := range keys { + if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 { + return false, errors.New("keys are not monotonic increasing") + } + + if values[i] == nil || values[i].Equal(&felt.Zero) { + return false, errors.New("range contains empty leaf") + } + } + + // Special case: no edge proof provided; the given range contains all leaves in the trie + if proof == nil { + tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen) + for i, key := range keys { + if err := tr.Update(key, values[i]); err != nil { + return false, err + } + } + + recomputedRoot := tr.Hash() + if !recomputedRoot.Equal(rootHash) { + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", rootHash.String(), recomputedRoot.String()) + } + + return false, nil // no more elements available + } + + var firstKey Path + firstKey.SetFelt(contractClassTrieHeight, first) + + // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values + // Empty range proof with more elements on the right is not accepted in this function. + // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. + if len(keys) == 0 { + rootKey, val, err := proofToPath(rootHash, nil, firstKey, proof, true) + if err != nil { + return false, err + } + + if val != nil || hasRightElement(rootKey, firstKey) { + return false, errors.New("more entries available") + } + + return false, nil + } + + last := keys[len(keys)-1] + + var lastKey Path + lastKey.SetFelt(contractClassTrieHeight, last) + + // Special case: there is only one element and two edge keys are the same + if len(keys) == 1 && firstKey.Equal(&lastKey) { + root, val, err := proofToPath(rootHash, nil, firstKey, proof, false) + if err != nil { + return false, err + } + + firstItemKey := new(Path).SetFelt(contractClassTrieHeight, keys[0]) + if !firstKey.Equal(firstItemKey) { + return false, errors.New("correct proof but invalid key") + } + + if val == nil || !values[0].Equal(val) { + return false, errors.New("correct proof but invalid value") + } + + return hasRightElement(root, firstKey), nil + } + + // In all other cases, we require two edge paths available. + // First, ensure that the last key is greater than the first key + if last.Cmp(first) <= 0 { + return false, errors.New("last key is less than first key") + } + + root, _, err := proofToPath(rootHash, nil, firstKey, proof, true) + if err != nil { + return false, err + } + + root, _, err = proofToPath(rootHash, root, lastKey, proof, true) + if err != nil { + return false, err + } + + empty, err := unsetInternal(root, firstKey, lastKey) + if err != nil { + return false, err + } + + tr := NewEmpty(contractClassTrieHeight, crypto.Pedersen) + if !empty { + tr.root = root + } + + for i, key := range keys { + if err := tr.Update(key, values[i]); err != nil { + return false, err + } + } + + newRoot := tr.Hash() + + // Verify that the recomputed root hash matches the provided root hash + if !newRoot.Equal(rootHash) { + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", rootHash.String(), newRoot.String()) + } + + return hasRightElement(root, lastKey), nil +} + +// proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining +// as hashes. The given edge proof can be existent or non-existent. +func proofToPath(rootHash *felt.Felt, root node, keyBits Path, proof *ProofNodeSet, allowNonExistent bool) (node, *felt.Felt, error) { + // Retrieves the node from the proof node set given the node hash + retrieveNode := func(hash felt.Felt) (node, error) { + n, _ := proof.Get(hash) + if n == nil { + return nil, fmt.Errorf("proof node not found, expected hash: %s", hash.String()) + } + return n, nil + } + + // Must resolve the root node first if it's not provided + if root == nil { + n, err := retrieveNode(*rootHash) + if err != nil { + return nil, nil, err + } + root = n + } + + var ( + err error + child, parent node + val *felt.Felt + ) + + parent = root + for { + msb := keyBits.MSB() // key gets modified in get(), we need the current msb to get the correct child during linking + child = get(parent, &keyBits, false) + switch n := child.(type) { + case nil: + if allowNonExistent { + return root, nil, nil + } + return nil, nil, errors.New("the node is not in the trie") + case *edgeNode: + parent = child + continue + case *binaryNode: + parent = child + continue + case *hashNode: + child, err = retrieveNode(n.Felt) + if err != nil { + return nil, nil, err + } + case *valueNode: + val = &n.Felt + } + // Link the parent and child + switch p := parent.(type) { + case *edgeNode: + p.child = child + case *binaryNode: + p.children[msb] = child + default: + panic(fmt.Sprintf("unknown parent node type: %T", p)) + } + + // We reached the leaf + if val != nil { + return root, val, nil + } + + parent = child + } +} + +// unsetInternal removes all internal node references (hashNode, embedded node). +// It should be called after a trie is constructed with two edge paths. Also +// the given boundary keys must be the ones used to construct the edge paths. +// +// It's the key step for range proof. All visited nodes should be marked dirty +// since the node content might be modified. +// +// Note we have the assumption here the given boundary keys are different +// and right is larger than left. +// +//nolint:gocyclo +func unsetInternal(n node, left, right Path) (bool, error) { + // Step down to the fork point. There are two scenarios that can happen: + // - the fork point is an edgeNode: either the key of left proof or + // right proof doesn't match with the edge node's path + // - the fork point is a binaryNode: both two edge proofs are allowed + // to point to a non-existent key + var ( + pos uint8 + parent node + + // fork indicator for edge nodes + edgeForkLeft, edgeForkRight int + ) + +findFork: + for { + switch rn := n.(type) { + case *edgeNode: + rn.flags = newFlag() + + if left.Len()-pos < rn.path.Len() { + edgeForkLeft = new(Path).LSBs(&left, pos).Cmp(rn.path) + } else { + subKey := new(Path).Subset(&left, pos, pos+rn.path.Len()) + edgeForkLeft = subKey.Cmp(rn.path) + } + + if right.Len()-pos < rn.path.Len() { + edgeForkRight = new(Path).LSBs(&right, pos).Cmp(rn.path) + } else { + subKey := new(Path).Subset(&right, pos, pos+rn.path.Len()) + edgeForkRight = subKey.Cmp(rn.path) + } + + if edgeForkLeft != 0 || edgeForkRight != 0 { + break findFork + } + + parent = n + n = rn.child + pos += rn.path.Len() + case *binaryNode: + rn.flags = newFlag() + + leftnode, rightnode := rn.children[left.Bit(pos)], rn.children[right.Bit(pos)] + if leftnode == nil || rightnode == nil || leftnode != rightnode { + break findFork + } + parent = n + n = rn.children[left.Bit(pos)] + pos++ + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } + } + + switch rn := n.(type) { + case *edgeNode: + // There can have these five scenarios: + // - both proofs are less than the trie path => no valid range + // - both proofs are greater than the trie path => no valid range + // - left proof is less and right proof is greater => valid range, unset the shortnode entirely + // - left proof points to the shortnode, but right proof is greater + // - right proof points to the shortnode, but left proof is less + if edgeForkLeft == -1 && edgeForkRight == -1 { + return false, ErrEmptyRange + } + if edgeForkLeft == 1 && edgeForkRight == 1 { + return false, ErrEmptyRange + } + if edgeForkLeft != 0 && edgeForkRight != 0 { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } + parent.(*binaryNode).children[left.Bit(pos-1)] = nil + return false, nil + } + // Only one proof points to non-existent key + if edgeForkRight != 0 { + if _, ok := rn.child.(*valueNode); ok { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } + parent.(*binaryNode).children[left.Bit(pos-1)] = nil + return false, nil + } + return false, unset(rn, rn.child, new(Path).LSBs(&left, pos), rn.path.Len(), false) + } + if edgeForkLeft != 0 { + if _, ok := rn.child.(*valueNode); ok { + // The fork point is root node, unset the entire trie + if parent == nil { + return true, nil + } + parent.(*binaryNode).children[right.Bit(pos-1)] = nil + return false, nil + } + return false, unset(rn, rn.child, new(Path).LSBs(&right, pos), rn.path.Len(), true) + } + return false, nil + case *binaryNode: + leftBit := left.Bit(pos) + rightBit := right.Bit(pos) + if leftBit == 0 && rightBit == 0 { + rn.children[1] = nil + } + if leftBit == 1 && rightBit == 1 { + rn.children[0] = nil + } + if err := unset(rn, rn.children[leftBit], new(Path).LSBs(&left, pos), 1, false); err != nil { + return false, err + } + if err := unset(rn, rn.children[rightBit], new(Path).LSBs(&right, pos), 1, true); err != nil { + return false, err + } + return false, nil + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + +// unset removes all internal node references either the left most or right most. +func unset(parent, child node, key *Path, pos uint8, removeLeft bool) error { + switch cld := child.(type) { + case *binaryNode: + keyBit := key.Bit(pos) + cld.flags = newFlag() // Mark dirty + if removeLeft { + // Remove left child if we're removing left side + if keyBit == 1 { + cld.children[0] = nil + } + } else { + // Remove right child if we're removing right side + if keyBit == 0 { + cld.children[1] = nil + } + } + return unset(cld, cld.children[key.Bit(pos)], key, pos+1, removeLeft) + + case *edgeNode: + keyPos := new(Path).LSBs(key, pos) + keyBit := key.Bit(pos - 1) + if !cld.pathMatches(keyPos) { + // We append zeros to the path to match the length of the remaining key + // The key length is guaranteed to be >= path length + edgePath := new(Path).AppendZeros(cld.path, keyPos.Len()-cld.path.Len()) + // Found fork point, non-existent branch + if removeLeft { + if edgePath.Cmp(keyPos) < 0 { + // Edge node path is in range, unset entire branch + parent.(*binaryNode).children[keyBit] = nil + } + } else { + if edgePath.Cmp(keyPos) > 0 { + parent.(*binaryNode).children[keyBit] = nil + } + } + return nil + } + if _, ok := cld.child.(*valueNode); ok { + parent.(*binaryNode).children[keyBit] = nil + return nil + } + cld.flags = newFlag() + return unset(cld, cld.child, key, pos+cld.path.Len(), removeLeft) + + case nil, *hashNode, *valueNode: + // Child is nil, nothing to unset + return nil + default: + panic("it shouldn't happen") // hashNode, valueNode + } +} + +// hasRightElement checks if there is a right sibling for the given key in the trie. +// This function assumes that the entire path has been resolved. +func hasRightElement(node node, key Path) bool { + for node != nil { + switch n := node.(type) { + case *binaryNode: + bit := key.MSB() + if bit == 0 && n.children[1] != nil { + // right sibling exists + return true + } + node = n.children[bit] + key.LSBs(&key, 1) + case *edgeNode: + if !n.pathMatches(&key) { + // There's a divergence in the path, check if the node path is greater than the key + // If so, that means that this node comes after the search key, which indicates that + // there are elements with larger values + var edgePath *Path + if key.Len() > n.path.Len() { + edgePath = new(Path).AppendZeros(n.path, key.Len()-n.path.Len()) + } else { + edgePath = n.path + } + return edgePath.Cmp(&key) > 0 + } + node = n.child + key.LSBs(&key, n.path.Len()) + case *valueNode: + return false // resolved the whole path + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } + } + return false +} + +// Resolves the whole path of the given key and node. +// If skipResolved is true, it will only return the immediate child node of the current node +func get(rn node, key *Path, skipResolved bool) node { + for { + switch n := rn.(type) { + case *edgeNode: + if !n.pathMatches(key) { + return nil + } + rn = n.child + key.LSBs(key, n.path.Len()) + case *binaryNode: + bit := key.MSB() + rn = n.children[bit] + key.LSBs(key, 1) + case *hashNode: + return n + case *valueNode: + return n + case nil: + return nil + } + + if !skipResolved { + return rn + } + } +} diff --git a/core/trie2/proof_test.go b/core/trie2/proof_test.go new file mode 100644 index 0000000000..69390425d7 --- /dev/null +++ b/core/trie2/proof_test.go @@ -0,0 +1,716 @@ +package trie2 + +import ( + "math/rand" + "testing" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/require" +) + +func TestProve(t *testing.T) { + n := 1000 + tempTrie, records := nonRandomTrie(t, n) + + for _, record := range records { + root := tempTrie.Hash() + + proofSet := NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, record.key, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", record.key.String()) + } + + if !val.Equal(record.value) { + t.Fatalf("expected value %s, got %s", record.value.String(), val.String()) + } + } +} + +func TestProveNonExistent(t *testing.T) { + n := 1000 + tempTrie, _ := nonRandomTrie(t, n) + + for i := 1; i < n+1; i++ { + root := tempTrie.Hash() + + keyFelt := new(felt.Felt).SetUint64(uint64(i + n)) + proofSet := NewProofNodeSet() + err := tempTrie.Prove(keyFelt, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, keyFelt, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", keyFelt.String()) + } + require.Equal(t, felt.Zero, val) + } +} + +func TestProveRandom(t *testing.T) { + tempTrie, records := randomTrie(t, 1000) + + for _, record := range records { + root := tempTrie.Hash() + + proofSet := NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, record.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + + if !val.Equal(record.value) { + t.Fatalf("expected value %s, got %s", record.value.String(), val.String()) + } + } +} + +func TestProveCustom(t *testing.T) { + tests := []testTrie{ + { + name: "simple binary", + buildFn: buildSimpleTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(1), + expected: new(felt.Felt).SetUint64(3), + }, + }, + }, + { + name: "simple double binary", + buildFn: buildSimpleDoubleBinaryTrie, + testKeys: []testKey{ + { + name: "prove existing key 0", + key: new(felt.Felt).SetUint64(0), + expected: new(felt.Felt).SetUint64(2), + }, + { + name: "prove existing key 3", + key: new(felt.Felt).SetUint64(3), + expected: new(felt.Felt).SetUint64(5), + }, + { + name: "prove non-existent key 2", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(0), + }, + { + name: "prove non-existent key 123", + key: new(felt.Felt).SetUint64(123), + expected: new(felt.Felt).SetUint64(0), + }, + }, + }, + { + name: "simple binary root", + buildFn: buildSimpleBinaryRootTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(0), + expected: utils.HexToFelt(t, "0xcc"), + }, + }, + }, + { + name: "left-right edge", + buildFn: func(t *testing.T) (*Trie, []*keyValue) { + tr, err := NewEmptyPedersen() + require.NoError(t, err) + + records := []*keyValue{ + {key: utils.HexToFelt(t, "0xff"), value: utils.HexToFelt(t, "0xaa")}, + } + + for _, record := range records { + err = tr.Update(record.key, record.value) + require.NoError(t, err) + } + return tr, records + }, + testKeys: []testKey{ + { + name: "prove existing key", + key: utils.HexToFelt(t, "0xff"), + expected: utils.HexToFelt(t, "0xaa"), + }, + }, + }, + { + name: "three key trie", + buildFn: build3KeyTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(6), + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tr, _ := test.buildFn(t) + + for _, tc := range test.testKeys { + t.Run(tc.name, func(t *testing.T) { + proofSet := NewProofNodeSet() + + root := tr.Hash() + err := tr.Prove(tc.key, proofSet) + require.NoError(t, err) + + val, err := VerifyProof(&root, tc.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + if !val.Equal(tc.expected) { + t.Fatalf("expected value %s, got %s", tc.expected.String(), val.String()) + } + }) + } + }) + } +} + +// TestRangeProof tests normal range proof with both edge proofs +func TestRangeProof(t *testing.T) { + n := 500 + tr, records := randomTrie(t, n) + root := tr.Hash() + + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } + + _, err = VerifyRangeProof(&root, records[start].key, keys, values, proof) + require.NoError(t, err) + } +} + +// TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs +func TestRangeProofWithNonExistentProof(t *testing.T) { + n := 500 + tr, records := randomTrie(t, n) + root := tr.Hash() + + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 + + first := decrementFelt(records[start].key) + if start != 0 && first.Equal(records[start-1].key) { + continue + } + + proof := NewProofNodeSet() + err := tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) + + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + _, err = VerifyRangeProof(&root, first, keys, values, proof) + require.NoError(t, err) + } +} + +// TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. +// One scenario is when there is a gap between the first element and the left edge proof. +func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { + n := 500 + tr, records := randomTrie(t, n) + root := tr.Hash() + + start, end := 100, 200 + first := decrementFelt(records[start].key) + + proof := NewProofNodeSet() + err := tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) + + start = 105 // Gap created + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + _, err = VerifyRangeProof(&root, first, keys, values, proof) + require.Error(t, err) +} + +func TestRangeProofCustom(t *testing.T) { + tr, records := build4KeysTrieD(t) + root := tr.Hash() + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[0].key, records[2].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[0].key, []*felt.Felt{records[0].key, records[1].key, records[2].key, records[3].key}, []*felt.Felt{records[0].value, records[1].value, records[2].value, records[3].value}, proof) + require.NoError(t, err) +} + +func TestOneElementRangeProof(t *testing.T) { + n := 1000 + tr, records := randomTrie(t, n) + root := tr.Hash() + + t.Run("both edge proofs with the same key", func(t *testing.T) { + start := 100 + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[start].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[start].key, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) + require.NoError(t, err) + }) + + t.Run("left non-existent edge proof", func(t *testing.T) { + start := 100 + proof := NewProofNodeSet() + err := tr.GetRangeProof(decrementFelt(records[start].key), records[start].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, decrementFelt(records[start].key), []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) + require.NoError(t, err) + }) + + t.Run("right non-existent edge proof", func(t *testing.T) { + end := 100 + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[end].key, incrementFelt(records[end].key), proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[end].key, []*felt.Felt{records[end].key}, []*felt.Felt{records[end].value}, proof) + require.NoError(t, err) + }) + + t.Run("both non-existent edge proofs", func(t *testing.T) { + start := 100 + first, last := decrementFelt(records[start].key), incrementFelt(records[start].key) + proof := NewProofNodeSet() + err := tr.GetRangeProof(first, last, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, first, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) + require.NoError(t, err) + }) + + t.Run("1 key trie", func(t *testing.T) { + tr, records := build1KeyTrie(t) + root := tr.Hash() + + proof := NewProofNodeSet() + err := tr.GetRangeProof(&felt.Zero, records[0].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, records[0].key, []*felt.Felt{records[0].key}, []*felt.Felt{records[0].value}, proof) + require.NoError(t, err) + }) +} + +// TestAllElementsRangeProof tests the range proof with all elements and nil proof. +func TestAllElementsRangeProof(t *testing.T) { + n := 1000 + tr, records := randomTrie(t, n) + root := tr.Hash() + + keys := make([]*felt.Felt, n) + values := make([]*felt.Felt, n) + for i, record := range records { + keys[i] = record.key + values[i] = record.value + } + + _, err := VerifyRangeProof(&root, nil, keys, values, nil) + require.NoError(t, err) + + // Should also work with proof + proof := NewProofNodeSet() + err = tr.GetRangeProof(records[0].key, records[n-1].key, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, keys[0], keys, values, proof) + require.NoError(t, err) +} + +// TestSingleSideRangeProof tests the range proof starting with zero. +func TestSingleSideRangeProof(t *testing.T) { + tr, records := randomTrie(t, 1000) + root := tr.Hash() + + for i := 0; i < len(records); i += 100 { + proof := NewProofNodeSet() + err := tr.GetRangeProof(&felt.Zero, records[i].key, proof) + require.NoError(t, err) + + keys := make([]*felt.Felt, i+1) + values := make([]*felt.Felt, i+1) + for j := range i + 1 { + keys[j] = records[j].key + values[j] = records[j].value + } + + _, err = VerifyRangeProof(&root, &felt.Zero, keys, values, proof) + require.NoError(t, err) + } +} + +func TestGappedRangeProof(t *testing.T) { + tr, records := nonRandomTrie(t, 5) + root := tr.Hash() + + first, last := 1, 4 + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[first].key, records[last].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := first; i <= last; i++ { + if i == (first+last)/2 { + continue + } + + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } + + _, err = VerifyRangeProof(&root, records[first].key, keys, values, proof) + require.Error(t, err) +} + +func TestEmptyRangeProof(t *testing.T) { + tr, records := randomTrie(t, 1000) + root := tr.Hash() + + cases := []struct { + pos int + err bool + }{ + {len(records) - 1, false}, + {500, true}, + } + + for _, c := range cases { + proof := NewProofNodeSet() + first := incrementFelt(records[c.pos].key) + err := tr.GetRangeProof(first, first, proof) + require.NoError(t, err) + + _, err = VerifyRangeProof(&root, first, nil, nil, proof) + if c.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } +} + +func TestHasRightElement(t *testing.T) { + tr, records := randomTrie(t, 10000) + root := tr.Hash() + + cases := []struct { + start int + end int + hasMore bool + }{ + {-1, 1, true}, // single element with non-existent left proof + {0, 1, true}, // single element with existent left proof + {0, 100, true}, // start to middle + {50, 100, true}, // middle only + {50, len(records), false}, // middle to end + {len(records) - 1, len(records), false}, // Single last element with two existent proofs(point to same key) + {0, len(records), false}, // The whole set with existent left proof + {-1, len(records), false}, // The whole set with non-existent left proof + } + + for _, c := range cases { + var ( + first *felt.Felt + start = c.start + end = c.end + proof = NewProofNodeSet() + ) + if start == -1 { + first = &felt.Zero + start = 0 + } else { + first = records[start].key + } + + err := tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } + + hasMore, err := VerifyRangeProof(&root, first, keys, values, proof) + require.NoError(t, err) + require.Equal(t, c.hasMore, hasMore) + } +} + +// TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. +func TestBadRangeProof(t *testing.T) { + tr, records := randomTrie(t, 5000) + root := tr.Hash() + + for range 500 { + start := rand.Intn(len(records)) + end := rand.Intn(len(records)-start) + start + 1 + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(t, err) + + keys := []*felt.Felt{} + values := []*felt.Felt{} + for j := start; j < end; j++ { + keys = append(keys, records[j].key) + values = append(values, records[j].value) + } + + first := keys[0] + testCase := rand.Intn(6) + + index := rand.Intn(end - start) + switch testCase { + case 0: // modified key + keys[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 1: // modified value + values[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 2: // out of order + index2 := rand.Intn(end - start) + if index2 == index { + continue + } + keys[index], keys[index2] = keys[index2], keys[index] + values[index], values[index2] = values[index2], values[index] + case 3: // set random key to empty + keys[index] = &felt.Zero + case 4: // set random value to empty + values[index] = &felt.Zero + case 5: // gapped + if end-start < 100 || index == 0 || index == end-start-1 { + continue + } + keys = append(keys[:index], keys[index+1:]...) + values = append(values[:index], values[index+1:]...) + } + _, err = VerifyRangeProof(&root, first, keys, values, proof) + if err == nil { + t.Fatalf("expected error for test case %d, index %d, start %d, end %d", testCase, index, start, end) + } + } +} + +func BenchmarkProve(b *testing.B) { + tr, records := randomTrie(b, 1000) + b.ResetTimer() + for i := range b.N { + proof := NewProofNodeSet() + key := records[i%len(records)].key + if err := tr.Prove(key, proof); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkVerifyProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root := tr.Hash() + + proofs := make([]*ProofNodeSet, 0, len(records)) + for _, record := range records { + proof := NewProofNodeSet() + if err := tr.Prove(record.key, proof); err != nil { + b.Fatal(err) + } + proofs = append(proofs, proof) + } + + b.ResetTimer() + for i := range b.N { + index := i % len(records) + if _, err := VerifyProof(&root, records[index].key, proofs[index], crypto.Pedersen); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkVerifyRangeProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root := tr.Hash() + + start := 2 + end := start + 500 + + proof := NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(b, err) + + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } + + b.ResetTimer() + for range b.N { + _, err := VerifyRangeProof(&root, keys[0], keys, values, proof) + require.NoError(b, err) + } +} + +func buildTrie(t *testing.T, records []*keyValue) *Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } + + tempTrie, err := NewEmptyPedersen() + require.NoError(t, err) + + for _, record := range records { + err = tempTrie.Update(record.key, record.value) + require.NoError(t, err) + } + + return tempTrie +} + +func build1KeyTrie(t *testing.T) (*Trie, []*keyValue) { + return nonRandomTrie(t, 1) +} + +func buildSimpleTrie(t *testing.T) (*Trie, []*keyValue) { + // (250, 0, x1) edge + // | + // (0,0,x1) binary + // / \ + // (2) (3) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + } + + return buildTrie(t, records), records +} + +func buildSimpleBinaryRootTrie(t *testing.T) (*Trie, []*keyValue) { + // PF + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + + // JUNO + // (0, 0, x) + // / \ + // (251, 0, cc) (251, 11111.., dd) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: utils.HexToFelt(t, "0xcc")}, + {key: utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), value: utils.HexToFelt(t, "0xdd")}, + } + return buildTrie(t, records), records +} + +//nolint:dupl +func buildSimpleDoubleBinaryTrie(t *testing.T) (*Trie, []*keyValue) { + // (249,0,x3) // Edge + // | + // (0, 0, x3) // Binary + // / \ + // (0,0,x1) // B (1, 1, 5) // Edge leaf + // / \ | + // (2) (3) (5) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + {key: new(felt.Felt).SetUint64(3), value: new(felt.Felt).SetUint64(5)}, + } + return buildTrie(t, records), records +} + +//nolint:dupl +func build3KeyTrie(t *testing.T) (*Trie, []*keyValue) { + // Starknet + // -------- + // + // Edge + // | + // Binary with len 249 parent + // / \ + // Binary (250) Edge with len 250 + // / \ / + // 0x4 0x5 0x6 child + + // Juno + // ---- + // + // Node (path 249) + // / \ + // Node (binary) \ + // / \ / + // 0x4 0x5 0x6 + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(4)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(2), value: new(felt.Felt).SetUint64(6)}, + } + + return buildTrie(t, records), records +} + +func decrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Sub(f, new(felt.Felt).SetUint64(1)) +} + +func incrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Add(f, new(felt.Felt).SetUint64(1)) +} + +type testKey struct { + name string + key *felt.Felt + expected *felt.Felt +} + +type testTrie struct { + name string + buildFn func(*testing.T) (*Trie, []*keyValue) + testKeys []testKey +} diff --git a/core/trie2/tracer.go b/core/trie2/tracer.go new file mode 100644 index 0000000000..3197965221 --- /dev/null +++ b/core/trie2/tracer.go @@ -0,0 +1,55 @@ +package trie2 + +import ( + "maps" +) + +// Tracks the changes to the trie, so that we know which node needs to be updated or deleted in the database +type nodeTracer struct { + inserts map[Path]struct{} + deletes map[Path]struct{} +} + +func newTracer() *nodeTracer { + return &nodeTracer{ + inserts: make(map[Path]struct{}), + deletes: make(map[Path]struct{}), + } +} + +// Tracks the newly inserted trie node. If the trie node was previously deleted, remove it from the deletion set +// as it means that the node will not be deleted in the database +func (t *nodeTracer) onInsert(key *Path) { + k := *key + if _, present := t.deletes[k]; present { + delete(t.deletes, k) + return + } + t.inserts[k] = struct{}{} +} + +// Tracks the newly deleted trie node. If the trie node was previously inserted, remove it from the insertion set +// as it means that the node will not be inserted in the database +func (t *nodeTracer) onDelete(key *Path) { + k := *key + if _, present := t.inserts[k]; present { + delete(t.inserts, k) + return + } + t.deletes[k] = struct{}{} +} + +func (t *nodeTracer) copy() *nodeTracer { + return &nodeTracer{ + inserts: maps.Clone(t.inserts), + deletes: maps.Clone(t.deletes), + } +} + +func (t *nodeTracer) deletedNodes() []Path { + keys := make([]Path, 0, len(t.deletes)) + for k := range t.deletes { + keys = append(keys, k) + } + return keys +} diff --git a/core/trie2/trie.go b/core/trie2/trie.go new file mode 100644 index 0000000000..d00b16cd73 --- /dev/null +++ b/core/trie2/trie.go @@ -0,0 +1,511 @@ +package trie2 + +import ( + "bytes" + "errors" + "fmt" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/triedb" + "github.com/NethermindEth/juno/core/trie2/trienode" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" +) + +const contractClassTrieHeight = 251 + +type Path = trieutils.BitArray + +type Trie struct { + // Height of the trie + height uint8 + + // The owner of the trie, only used for contract trie. If not empty, this is a storage trie. + owner felt.Felt + + // The root node of the trie + root node + + // Hash function used to hash the trie nodes + hashFn crypto.HashFn + + // The underlying database to store and retrieve trie nodes + db triedb.TrieDB + + // Check if the trie has been committed. Trie is unusable once committed. + committed bool + + // Maintains the records of trie changes, ensuring all nodes are modified or garbage collected properly + nodeTracer *nodeTracer + + // Tracks the number of leaves inserted since the last hashing operation + pendingHashes int + + // Tracks the total number of updates (inserts/deletes) since the last commit + pendingUpdates int +} + +func New(id *ID, height uint8, hashFn crypto.HashFn, txn db.Transaction) (*Trie, error) { + database := triedb.New(txn, id.Bucket()) + tr := &Trie{ + owner: id.Owner, + height: height, + hashFn: hashFn, + db: database, + nodeTracer: newTracer(), + } + + root, err := tr.resolveNode(nil, Path{}) + if err == nil { + tr.root = root + } else if !errors.Is(err, db.ErrKeyNotFound) { + return nil, err + } + + return tr, nil +} + +// Creates an empty trie, only used for temporary trie construction +func NewEmpty(height uint8, hashFn crypto.HashFn) *Trie { + return &Trie{ + height: height, + hashFn: hashFn, + root: nil, + nodeTracer: newTracer(), + db: triedb.EmptyDatabase{}, + } +} + +// Modifies or inserts a key-value pair in the trie. +// If value is zero, the key is deleted from the trie. +func (t *Trie) Update(key, value *felt.Felt) error { + if t.committed { + return ErrCommitted + } + + if err := t.update(key, value); err != nil { + return err + } + t.pendingUpdates++ + t.pendingHashes++ + + return nil +} + +// Retrieves the value associated with the given key. +// Returns felt.Zero if the key doesn't exist. +// May update the trie's internal structure if nodes need to be resolved. +// TODO(weiihann): +// The State should keep track of the modified key and values, so that we can avoid traversing the trie +// No action needed for the Trie, remove this once State provides the functionality +func (t *Trie) Get(key *felt.Felt) (felt.Felt, error) { + if t.committed { + return felt.Zero, ErrCommitted + } + + k := t.FeltToPath(key) + + var ret felt.Felt + val, root, didResolve, err := t.get(t.root, new(Path), &k) + // In Starknet, a non-existent key is mapped to felt.Zero + if val == nil { + ret = felt.Zero + } else { + ret = *val + } + + if err == nil && didResolve { + t.root = root + } + + return ret, err +} + +// Removes the given key from the trie. +func (t *Trie) Delete(key *felt.Felt) error { + if t.committed { + return ErrCommitted + } + + k := t.FeltToPath(key) + _, n, err := t.delete(t.root, new(Path), &k) + if err != nil { + return err + } + t.root = n + t.pendingUpdates++ + t.pendingHashes++ + return nil +} + +// Returns the root hash of the trie. Calling this method will also cache the hash of each node in the trie. +func (t *Trie) Hash() felt.Felt { + hash, cached := t.hashRoot() + t.root = cached + return hash.(*hashNode).Felt +} + +// Collapses the trie into a single hash node and flush the node changes to the database. +func (t *Trie) Commit() (felt.Felt, error) { + defer func() { + t.committed = true + }() + + // Trie is empty and can be classified into two types of situations: + // (a) The trie was empty and no update happens => return empty root + // (b) The trie was non-empty and all nodes are dropped => commit and return empty root + if t.root == nil { + paths := t.nodeTracer.deletedNodes() + if len(paths) == 0 { // case (a) + return felt.Zero, nil + } + // case (b) + nodes := trienode.NewNodeSet(t.owner) + for _, path := range paths { + nodes.Add(path, trienode.NewDeleted()) + } + err := nodes.ForEach(true, func(key trieutils.BitArray, node *trienode.Node) error { + return t.db.Delete(t.owner, key) + }) + return felt.Zero, err + } + + rootHash := t.Hash() + if hashedNode, dirty := t.root.cache(); !dirty { + t.root = hashedNode + return rootHash, nil + } + + nodes := trienode.NewNodeSet(t.owner) + for _, Path := range t.nodeTracer.deletedNodes() { + nodes.Add(Path, trienode.NewDeleted()) + } + + t.root = newCollector(nodes).Collect(t.root, t.pendingUpdates > 100) //nolint:mnd // TODO(weiihann): 100 is arbitrary + t.pendingUpdates = 0 + + err := nodes.ForEach(true, func(key trieutils.BitArray, node *trienode.Node) error { + if node.IsDeleted() { + return t.db.Delete(t.owner, key) + } + return t.db.Put(t.owner, key, node.Blob()) + }) + if err != nil { + return felt.Felt{}, err + } + + return rootHash, nil +} + +func (t *Trie) Copy() *Trie { + return &Trie{ + height: t.height, + owner: t.owner, + root: t.root, + hashFn: t.hashFn, + committed: t.committed, + db: t.db, + nodeTracer: t.nodeTracer.copy(), + pendingHashes: t.pendingHashes, + pendingUpdates: t.pendingUpdates, + } +} + +func (t *Trie) get(n node, prefix, key *Path) (*felt.Felt, node, bool, error) { + switch n := n.(type) { + case *edgeNode: + if !n.pathMatches(key) { + return nil, n, false, nil + } + val, child, didResolve, err := t.get(n.child, new(Path).Append(prefix, n.path), key.LSBs(key, n.path.Len())) + if err == nil && didResolve { + n = n.copy() + n.child = child + } + return val, n, didResolve, err + case *binaryNode: + bit := key.MSB() + val, child, didResolve, err := t.get(n.children[bit], new(Path).AppendBit(prefix, bit), key.LSBs(key, 1)) + if err == nil && didResolve { + n = n.copy() + n.children[bit] = child + } + return val, n, didResolve, err + case *hashNode: + child, err := t.resolveNode(n, *prefix) + if err != nil { + return nil, nil, false, err + } + value, newNode, _, err := t.get(child, prefix, key) + return value, newNode, true, err + case *valueNode: + return &n.Felt, n, false, nil + case nil: + return nil, nil, false, nil + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +// Modifies the trie by either inserting/updating a value or deleting a key. +// The operation is determined by whether the value is zero (delete) or non-zero (insert/update). +func (t *Trie) update(key, value *felt.Felt) error { + k := t.FeltToPath(key) + if value.IsZero() { + _, n, err := t.delete(t.root, new(Path), &k) + if err != nil { + return err + } + t.root = n + } else { + _, n, err := t.insert(t.root, new(Path), &k, &valueNode{Felt: *value}) + if err != nil { + return err + } + t.root = n + } + return nil +} + +//nolint:gocyclo,funlen +func (t *Trie) insert(n node, prefix, key *Path, value node) (bool, node, error) { + // We reach the end of the key + if key.Len() == 0 { + if v, ok := n.(*valueNode); ok { + vFelt := value.(*valueNode).Felt + return !v.Equal(&vFelt), value, nil + } + return true, value, nil + } + + switch n := n.(type) { + case *edgeNode: + match := n.commonPath(key) // get the matching bits between the current node and the key + // If the match is the same as the Path, just keep this edge node as it is and update the value + if match.Len() == n.path.Len() { + dirty, newNode, err := t.insert(n.child, new(Path).Append(prefix, n.path), key.LSBs(key, match.Len()), value) + if !dirty || err != nil { + return false, n, err + } + return true, &edgeNode{ + path: n.path, + child: newNode, + flags: newFlag(), + }, nil + } + // Otherwise branch out at the bit position where they differ + branch := &binaryNode{flags: newFlag()} + var err error + pathPrefix := new(Path).MSBs(n.path, match.Len()+1) + _, branch.children[n.path.Bit(match.Len())], err = t.insert( + nil, new(Path).Append(prefix, pathPrefix), new(Path).LSBs(n.path, match.Len()+1), n.child, + ) + if err != nil { + return false, n, err + } + + keyPrefix := new(Path).MSBs(key, match.Len()+1) + _, branch.children[key.Bit(match.Len())], err = t.insert( + nil, new(Path).Append(prefix, keyPrefix), new(Path).LSBs(key, match.Len()+1), value, + ) + if err != nil { + return false, n, err + } + + // Replace this edge node with the new binary node if it occurs at the current MSB + if match.IsEmpty() { + return true, branch, nil + } + matchPrefix := new(Path).MSBs(key, match.Len()) + t.nodeTracer.onInsert(new(Path).Append(prefix, matchPrefix)) + + // Otherwise, create a new edge node with the Path being the common Path and the branch as the child + return true, &edgeNode{path: matchPrefix, child: branch, flags: newFlag()}, nil + + case *binaryNode: + // Go to the child node based on the MSB of the key + bit := key.MSB() + dirty, newNode, err := t.insert( + n.children[bit], new(Path).AppendBit(prefix, bit), new(Path).LSBs(key, 1), value, + ) + if !dirty || err != nil { + return false, n, err + } + // Replace the child node with the new node + n = n.copy() + n.flags = newFlag() + n.children[bit] = newNode + return true, n, nil + case nil: + t.nodeTracer.onInsert(prefix) + // We reach the end of the key, return the value node + if key.IsEmpty() { + return true, value, nil + } + // Otherwise, return a new edge node with the Path being the key and the value as the child + return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil + case *hashNode: + child, err := t.resolveNode(n, *prefix) + if err != nil { + return false, n, err + } + dirty, newNode, err := t.insert(child, prefix, key, value) + if !dirty || err != nil { + return false, child, err + } + return true, newNode, nil + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +//nolint:gocyclo,funlen +func (t *Trie) delete(n node, prefix, key *Path) (bool, node, error) { + switch n := n.(type) { + case *edgeNode: + match := n.commonPath(key) + // Mismatched, don't do anything + if match.Len() < n.path.Len() { + return false, n, nil + } + // If the whole key matches, remove the entire edge node + if match.Len() == key.Len() { + t.nodeTracer.onDelete(prefix) + return true, nil, nil + } + + // Otherwise, key is longer than current node path, so we need to delete the child. + // Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie. + keyPrefix := new(Path).MSBs(key, n.path.Len()) + dirty, child, err := t.delete(n.child, new(Path).Append(prefix, keyPrefix), new(Path).LSBs(key, n.path.Len())) + if !dirty || err != nil { + return false, n, err + } + switch child := child.(type) { + case *edgeNode: + t.nodeTracer.onDelete(new(Path).Append(prefix, n.path)) + return true, &edgeNode{path: new(Path).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil + default: + return true, &edgeNode{path: new(Path).Set(n.path), child: child, flags: newFlag()}, nil + } + case *binaryNode: + bit := key.MSB() + keyPrefix := new(Path).MSBs(key, 1) + dirty, newNode, err := t.delete(n.children[bit], new(Path).Append(prefix, keyPrefix), key.LSBs(key, 1)) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = newFlag() + n.children[bit] = newNode + + // If the child node is not nil, that means we still have 2 children in this binary node + if newNode != nil { + return true, n, nil + } + + // Otherwise, we need to combine this binary node with the other child + // If it's a hash node, we need to resolve it first. + // If the other child is an edge node, we prepend the bit prefix to the other child path + other := bit ^ 1 + bitPrefix := new(Path).SetBit(other) + + if hn, ok := n.children[other].(*hashNode); ok { + var cPath Path + cPath.Append(prefix, bitPrefix) + cNode, err := t.resolveNode(hn, cPath) + if err != nil { + return false, nil, err + } + n.children[other] = cNode + } + + if cn, ok := n.children[other].(*edgeNode); ok { + t.nodeTracer.onDelete(new(Path).Append(prefix, bitPrefix)) + return true, &edgeNode{ + path: new(Path).Append(bitPrefix, cn.path), + child: cn.child, + flags: newFlag(), + }, nil + } + + // other child is not an edge node, create a new edge node with the bit prefix as the Path + // containing the other child as the child + return true, &edgeNode{path: bitPrefix, child: n.children[other], flags: newFlag()}, nil + case *valueNode: + return true, nil, nil + case nil: + return false, nil, nil + case *hashNode: + child, err := t.resolveNode(n, *prefix) + if err != nil { + return false, nil, err + } + + dirty, newNode, err := t.delete(child, prefix, key) + if !dirty || err != nil { + return false, child, err + } + return true, newNode, nil + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +// Resolves the node at the given path from the database +func (t *Trie) resolveNode(hn *hashNode, path Path) (node, error) { + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer func() { + buf.Reset() + bufferPool.Put(buf) + }() + + _, err := t.db.Get(buf, t.owner, path) + if err != nil { + return nil, err + } + + blob := buf.Bytes() + + var hash *felt.Felt + if hn != nil { + hash = &hn.Felt + } + return decodeNode(blob, hash, path.Len(), t.height) +} + +// Calculate the hash of the root node +func (t *Trie) hashRoot() (node, node) { + if t.root == nil { + return &hashNode{Felt: felt.Zero}, nil + } + h := newHasher(t.hashFn, t.pendingHashes > 100) //nolint:mnd //TODO(weiihann): 100 is arbitrary + hashed, cached := h.hash(t.root) + t.pendingHashes = 0 + return hashed, cached +} + +// Converts a Felt value into a Path representation suitable to +// use as a trie key with the specified height. +func (t *Trie) FeltToPath(f *felt.Felt) Path { + var key Path + key.SetFelt(t.height, f) + return key +} + +func (t *Trie) String() string { + if t.root == nil { + return "" + } + return t.root.String() +} + +func NewEmptyPedersen() (*Trie, error) { + return New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db.NewMemTransaction()) +} + +func NewEmptyPoseidon() (*Trie, error) { + return New(TrieID(), contractClassTrieHeight, crypto.Poseidon, db.NewMemTransaction()) +} diff --git a/core/trie2/trie_test.go b/core/trie2/trie_test.go new file mode 100644 index 0000000000..4b19a323ef --- /dev/null +++ b/core/trie2/trie_test.go @@ -0,0 +1,484 @@ +package trie2 + +import ( + "fmt" + "io" + "math/rand" + "reflect" + "sort" + "testing" + "testing/quick" + + "github.com/NethermindEth/juno/core/crypto" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" + "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/require" +) + +func TestUpdate(t *testing.T) { + verifyRecords := func(t *testing.T, tr *Trie, records []*keyValue) { + t.Helper() + for _, record := range records { + got, err := tr.Get(record.key) + require.NoError(t, err) + require.True(t, got.Equal(record.value), "expected %v, got %v", record.value, got) + } + } + + t.Run("sequential", func(t *testing.T) { + tr, records := nonRandomTrie(t, 10000) + verifyRecords(t, tr, records) + }) + + t.Run("random", func(t *testing.T) { + tr, records := randomTrie(t, 10000) + verifyRecords(t, tr, records) + }) +} + +func TestDelete(t *testing.T) { + verifyDelete := func(t *testing.T, tr *Trie, records []*keyValue) { + t.Helper() + for _, record := range records { + err := tr.Delete(record.key) + require.NoError(t, err) + + got, err := tr.Get(record.key) + require.NoError(t, err) + require.True(t, got.Equal(&felt.Zero), "expected %v, got %v", &felt.Zero, got) + } + } + + t.Run("sequential", func(t *testing.T) { + tr, records := nonRandomTrie(t, 10000) + verifyDelete(t, tr, records) + }) + + t.Run("random", func(t *testing.T) { + tr, records := randomTrie(t, 10000) + // Delete in reverse order for random case + for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 { + records[i], records[j] = records[j], records[i] + } + verifyDelete(t, tr, records) + }) +} + +// The expected hashes are taken from Pathfinder's tests +func TestHash(t *testing.T) { + t.Run("empty", func(t *testing.T) { + tr, _ := NewEmptyPedersen() + hash := tr.Hash() + require.Equal(t, felt.Zero, hash) + }) + + t.Run("one leaf", func(t *testing.T) { + tr, _ := NewEmptyPedersen() + err := tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + hash := tr.Hash() + + expected := "0x2ab889bd35e684623df9b4ea4a4a1f6d9e0ef39b67c1293b8a89dd17e351330" + require.Equal(t, expected, hash.String(), "expected %s, got %s", expected, hash.String()) + }) + + t.Run("two leaves", func(t *testing.T) { + tr, _ := NewEmptyPedersen() + err := tr.Update(new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(2)) + require.NoError(t, err) + err = tr.Update(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(3)) + require.NoError(t, err) + root := tr.Hash() + + expected := "0x79acdb7a3d78052114e21458e8c4aecb9d781ce79308193c61a2f3f84439f66" + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("three leaves", func(t *testing.T) { + tr, _ := NewEmptyPedersen() + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(16), + new(felt.Felt).SetUint64(17), + new(felt.Felt).SetUint64(19), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(10), + new(felt.Felt).SetUint64(11), + new(felt.Felt).SetUint64(12), + } + + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x7e2184e9e1a651fd556b42b4ff10e44a71b1709f641e0203dc8bd2b528e5e81" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("double binary", func(t *testing.T) { + // (249,0,x3) + // | + // (0, 0, x3) + // / \ + // (0,0,x1) (1, 1, 5) + // / \ | + // (2) (3) (5) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + new(felt.Felt).SetUint64(1), + new(felt.Felt).SetUint64(3), + } + + vals := []*felt.Felt{ + new(felt.Felt).SetUint64(2), + new(felt.Felt).SetUint64(3), + new(felt.Felt).SetUint64(5), + } + + tr, _ := NewEmptyPedersen() + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x6a316f09913454294c6b6751dea8449bc2e235fdc04b2ab0e1ac7fea25cc34f" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) + + t.Run("binary root", func(t *testing.T) { + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) + + keys := []*felt.Felt{ + new(felt.Felt).SetUint64(0), + utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), + } + + vals := []*felt.Felt{ + utils.HexToFelt(t, "0xcc"), + utils.HexToFelt(t, "0xdd"), + } + + tr, _ := NewEmptyPedersen() + for i := range keys { + err := tr.Update(keys[i], vals[i]) + require.NoError(t, err) + } + + expected := "0x542ced3b6aeef48339129a03e051693eff6a566d3a0a94035fa16ab610dc9e2" + root := tr.Hash() + require.Equal(t, expected, root.String(), "expected %s, got %s", expected, root.String()) + }) +} + +func TestCommit(t *testing.T) { + verifyCommit := func(t *testing.T, records []*keyValue) { + t.Helper() + db := db.NewMemTransaction() + tr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db) + require.NoError(t, err) + + for _, record := range records { + err := tr.Update(record.key, record.value) + require.NoError(t, err) + } + + _, err = tr.Commit() + require.NoError(t, err) + + tr2, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, db) + require.NoError(t, err) + + for _, record := range records { + got, err := tr2.Get(record.key) + require.NoError(t, err) + require.True(t, got.Equal(record.value), "expected %v, got %v", record.value, got) + } + } + + t.Run("sequential", func(t *testing.T) { + _, records := nonRandomTrie(t, 10000) + verifyCommit(t, records) + }) + + t.Run("random", func(t *testing.T) { + _, records := randomTrie(t, 10000) + verifyCommit(t, records) + }) +} + +func TestRandom(t *testing.T) { + if err := quick.Check(runRandTestBool, nil); err != nil { + if cerr, ok := err.(*quick.CheckError); ok { + t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In)) + } + t.Fatal(err) + } +} + +func TestSpecificRandomFailure(t *testing.T) { + // Create test steps that match the failing sequence + key1 := utils.HexToFelt(t, "0x5c0d77d04056ae1e0c49bce8a3223b0373ccf94c1b04935d") + key2 := utils.HexToFelt(t, "0x19bc6d500358f3046b0743c903") + key3 := utils.HexToFelt(t, "0xf7d2180a5a138b325ea522e2b65b58a5dffb859b1fbbdab0e5efb125e8a840") + + steps := []randTestStep{ + {op: opProve, key: key1}, + {op: opCommit}, + {op: opDelete, key: key2}, + {op: opHash}, + {op: opCommit}, + {op: opCommit}, + {op: opGet, key: key2}, + {op: opUpdate, key: key2, value: new(felt.Felt).SetUint64(7)}, + {op: opHash}, + {op: opDelete, key: key1}, + {op: opUpdate, key: key1, value: new(felt.Felt).SetUint64(10)}, + {op: opHash}, + {op: opGet, key: key2}, + {op: opUpdate, key: key1, value: new(felt.Felt).SetUint64(13)}, + {op: opCommit}, + {op: opHash}, + {op: opProve, key: key1}, + {op: opProve, key: key3}, + {op: opUpdate, key: key3, value: new(felt.Felt).SetUint64(18)}, + {op: opCommit}, + {op: opGet, key: key2}, + {op: opUpdate, key: key2, value: new(felt.Felt).SetUint64(21)}, + {op: opHash}, + {op: opHash}, + {op: opGet, key: key2}, + {op: opProve, key: key1}, + {op: opUpdate, key: key2, value: new(felt.Felt).SetUint64(26)}, + {op: opHash}, + {op: opGet, key: key2}, + {op: opCommit}, + {op: opCommit}, + {op: opProve, key: key3}, // This is where the original failure occurred + } + + // Add debug logging + t.Log("Starting test sequence") + err := runRandTest(steps) + if err != nil { + t.Logf("Test failed at step: %v", err) + // Print the state of the trie at failure + for i, step := range steps { + if step.err != nil { + t.Logf("Failed at step %d: %v", i, step) + break + } + } + } + require.NoError(t, err, "specific random test sequence should not fail") +} + +const ( + opUpdate = iota + opDelete + opGet + opHash + opCommit + opProve + opMax // max number of operations, not an actual operation +) + +type randTestStep struct { + op int + key *felt.Felt // for opUpdate, opDelete, opGet + value *felt.Felt // for opUpdate + err error +} + +type randTest []randTestStep + +func (randTest) Generate(r *rand.Rand, size int) reflect.Value { + finishedFn := func() bool { + size-- + return size == 0 + } + return reflect.ValueOf(generateSteps(finishedFn, r)) +} + +func generateSteps(finished func() bool, r io.Reader) randTest { + var allKeys []*felt.Felt + random := []byte{0} + + genKey := func() *felt.Felt { + _, _ = r.Read(random) + // Create a new key with 10% probability or when < 2 keys exist + if len(allKeys) < 2 || random[0]%100 > 90 { + size := random[0] % 32 // ensure key size is between 1 and 32 bytes + key := make([]byte, size) + _, _ = r.Read(key) + allKeys = append(allKeys, new(felt.Felt).SetBytes(key)) + } + // 90% probability to return an existing key + idx := int(random[0]) % len(allKeys) + return allKeys[idx] + } + + var steps randTest + for !finished() { + _, _ = r.Read(random) + step := randTestStep{op: int(random[0]) % opMax} + switch step.op { + case opUpdate: + step.key = genKey() + step.value = new(felt.Felt).SetUint64(uint64(len(steps))) + case opGet, opDelete, opProve: + step.key = genKey() + } + steps = append(steps, step) + } + return steps +} + +func runRandTestBool(rt randTest) bool { + return runRandTest(rt) == nil +} + +//nolint:gocyclo +func runRandTest(rt randTest) error { + txn := db.NewMemTransaction() + tr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, txn) + if err != nil { + return err + } + + values := make(map[felt.Felt]felt.Felt) // keeps track of the content of the trie + + for i, step := range rt { + switch step.op { + case opUpdate: + err := tr.Update(step.key, step.value) + if err != nil { + rt[i].err = fmt.Errorf("update failed: key %s, %w", step.key.String(), err) + } + values[*step.key] = *step.value + case opDelete: + err := tr.Delete(step.key) + if err != nil { + rt[i].err = fmt.Errorf("delete failed: key %s, %w", step.key.String(), err) + } + delete(values, *step.key) + case opGet: + got, err := tr.Get(step.key) + if err != nil { + rt[i].err = fmt.Errorf("get failed: key %s, %w", step.key.String(), err) + } + want := values[*step.key] + if !got.Equal(&want) { + rt[i].err = fmt.Errorf("mismatch in get: key %s, expected %v, got %v", step.key.String(), want.String(), got.String()) + } + case opProve: + hash := tr.Hash() + if hash.Equal(&felt.Zero) { + continue + } + proof := NewProofNodeSet() + err := tr.Prove(step.key, proof) + if err != nil { + rt[i].err = fmt.Errorf("prove failed for key %s: %w", step.key.String(), err) + } + _, err = VerifyProof(&hash, step.key, proof, crypto.Pedersen) + if err != nil { + rt[i].err = fmt.Errorf("verify proof failed for key %s: %w", step.key.String(), err) + } + case opHash: + tr.Hash() + case opCommit: + _, err := tr.Commit() + if err != nil { + rt[i].err = fmt.Errorf("commit failed: %w", err) + } + newtr, err := New(TrieID(), contractClassTrieHeight, crypto.Pedersen, txn) + if err != nil { + rt[i].err = fmt.Errorf("new trie failed: %w", err) + } + tr = newtr + } + + if rt[i].err != nil { + return rt[i].err + } + } + return nil +} + +type keyValue struct { + key *felt.Felt + value *felt.Felt +} + +func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) { + tr, _ := NewEmptyPedersen() + records := make([]*keyValue, numKeys) + + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + err := tr.Update(key, key) + require.NoError(t, err) + } + + return tr, records +} + +func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) { + rrand := rand.New(rand.NewSource(3)) + + tr, _ := NewEmptyPedersen() + records := make([]*keyValue, n) + + for i := range n { + key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) + records[i] = &keyValue{key: key, value: key} + err := tr.Update(key, key) + require.NoError(t, err) + } + + // Sort records by key + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 + }) + + return tr, records +} + +func build4KeysTrieD(t *testing.T) (*Trie, []*keyValue) { + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(4)}, + {key: new(felt.Felt).SetUint64(4), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(6), value: new(felt.Felt).SetUint64(6)}, + {key: new(felt.Felt).SetUint64(7), value: new(felt.Felt).SetUint64(7)}, + } + + return buildTestTrie(t, records), records +} + +func buildTestTrie(t *testing.T, records []*keyValue) *Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } + + tempTrie, _ := NewEmptyPedersen() + + for _, record := range records { + err := tempTrie.Update(record.key, record.value) + require.NoError(t, err) + } + + return tempTrie +} diff --git a/core/trie2/triedb/database.go b/core/trie2/triedb/database.go new file mode 100644 index 0000000000..979c0ab840 --- /dev/null +++ b/core/trie2/triedb/database.go @@ -0,0 +1,128 @@ +package triedb + +import ( + "bytes" + "errors" + "sync" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/NethermindEth/juno/db" +) + +var ErrCallEmptyDatabase = errors.New("call to empty database") + +var dbBufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +var ( + _ TrieDB = (*Database)(nil) + _ TrieDB = (*EmptyDatabase)(nil) +) + +type TrieDB interface { + Get(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) (int, error) + Put(owner felt.Felt, path trieutils.BitArray, blob []byte) error + Delete(owner felt.Felt, path trieutils.BitArray) error +} + +type Database struct { + txn db.Transaction + prefix db.Bucket +} + +func New(txn db.Transaction, prefix db.Bucket) *Database { + return &Database{txn: txn, prefix: prefix} +} + +func (d *Database) Get(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) (int, error) { + dbBuf := dbBufferPool.Get().(*bytes.Buffer) + dbBuf.Reset() + defer func() { + dbBuf.Reset() + dbBufferPool.Put(dbBuf) + }() + + if err := d.dbKey(dbBuf, owner, path); err != nil { + return 0, err + } + + err := d.txn.Get(dbBuf.Bytes(), func(blob []byte) error { + buf.Write(blob) + return nil + }) + if err != nil { + return 0, err + } + + return buf.Len(), nil +} + +func (d *Database) Put(owner felt.Felt, path trieutils.BitArray, blob []byte) error { + buffer := dbBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer func() { + buffer.Reset() + dbBufferPool.Put(buffer) + }() + + if err := d.dbKey(buffer, owner, path); err != nil { + return err + } + + return d.txn.Set(buffer.Bytes(), blob) +} + +func (d *Database) Delete(owner felt.Felt, path trieutils.BitArray) error { + buffer := dbBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer func() { + buffer.Reset() + dbBufferPool.Put(buffer) + }() + + if err := d.dbKey(buffer, owner, path); err != nil { + return err + } + + return d.txn.Delete(buffer.Bytes()) +} + +func (d *Database) dbKey(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) error { + _, err := buf.Write(d.prefix.Key()) + if err != nil { + return err + } + + if owner != (felt.Felt{}) { + oBytes := owner.Bytes() + _, err = buf.Write(oBytes[:]) + if err != nil { + return err + } + } + + _, err = path.Write(buf) + if err != nil { + return err + } + + return nil +} + +type EmptyDatabase struct{} + +func (EmptyDatabase) Get(buf *bytes.Buffer, owner felt.Felt, path trieutils.BitArray) (int, error) { + return 0, ErrCallEmptyDatabase +} + +func (EmptyDatabase) Put(owner felt.Felt, path trieutils.BitArray, blob []byte) error { + return ErrCallEmptyDatabase +} + +func (EmptyDatabase) Delete(owner felt.Felt, path trieutils.BitArray) error { + return ErrCallEmptyDatabase +} diff --git a/core/trie2/trienode/node.go b/core/trie2/trienode/node.go new file mode 100644 index 0000000000..c27e9e7775 --- /dev/null +++ b/core/trie2/trienode/node.go @@ -0,0 +1,31 @@ +package trienode + +import ( + "github.com/NethermindEth/juno/core/felt" +) + +// Represents a raw trie node, which contains the encoded blob and the hash of the node. +type Node struct { + blob []byte + hash felt.Felt +} + +func (r *Node) IsDeleted() bool { + return len(r.blob) == 0 +} + +func (r *Node) Blob() []byte { + return r.blob +} + +func (r *Node) Hash() felt.Felt { + return r.hash +} + +func NewNode(hash felt.Felt, blob []byte) *Node { + return &Node{hash: hash, blob: blob} +} + +func NewDeleted() *Node { + return &Node{hash: felt.Felt{}, blob: nil} +} diff --git a/core/trie2/trienode/nodeset.go b/core/trie2/trienode/nodeset.go new file mode 100644 index 0000000000..8b0c2f6f02 --- /dev/null +++ b/core/trie2/trienode/nodeset.go @@ -0,0 +1,97 @@ +package trienode + +import ( + "fmt" + "maps" + "sort" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" +) + +// Contains a set of nodes, which are indexed by their path in the trie. +// It is not thread safe. +type NodeSet struct { + Owner felt.Felt // The owner (i.e. contract address) + Nodes map[trieutils.BitArray]*Node + updates int // the count of updated and inserted nodes + deletes int // the count of deleted nodes +} + +func NewNodeSet(owner felt.Felt) *NodeSet { + return &NodeSet{Owner: owner, Nodes: make(map[trieutils.BitArray]*Node)} +} + +func (ns *NodeSet) Add(key trieutils.BitArray, node *Node) { + if node.IsDeleted() { + ns.deletes += 1 + } else { + ns.updates += 1 + } + ns.Nodes[key] = node +} + +// Iterates over the nodes in a sorted order and calls the callback for each node. +func (ns *NodeSet) ForEach(desc bool, callback func(key trieutils.BitArray, node *Node) error) error { + paths := make([]trieutils.BitArray, 0, len(ns.Nodes)) + for key := range ns.Nodes { + paths = append(paths, key) + } + + if desc { // longest path first + sort.Slice(paths, func(i, j int) bool { + return paths[i].Cmp(&paths[j]) > 0 + }) + } else { + sort.Slice(paths, func(i, j int) bool { + return paths[i].Cmp(&paths[j]) < 0 + }) + } + + for _, key := range paths { + if err := callback(key, ns.Nodes[key]); err != nil { + return err + } + } + + return nil +} + +// Merges the other node set into the current node set. +// The owner of both node sets must be the same. +func (ns *NodeSet) MergeSet(other *NodeSet) error { + if ns.Owner != other.Owner { + return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, other.Owner) + } + maps.Copy(ns.Nodes, other.Nodes) + ns.updates += other.updates + ns.deletes += other.deletes + return nil +} + +// Adds a set of nodes to the current node set. +func (ns *NodeSet) Merge(owner felt.Felt, other map[trieutils.BitArray]*Node) error { + if ns.Owner != owner { + return fmt.Errorf("cannot merge node sets with different owners %x-%x", ns.Owner, owner) + } + + for path, node := range other { + prev, ok := ns.Nodes[path] + if ok { // node already exists, revoke the counter first + if prev.IsDeleted() { + ns.deletes -= 1 + } else { + ns.updates -= 1 + } + } + // overwrite the existing node (if it exists) + if node.IsDeleted() { + ns.deletes += 1 + } else { + ns.updates += 1 + } + ns.Nodes[path] = node + } + + return nil +} diff --git a/core/trie2/trienode/trienode_test.go b/core/trie2/trienode/trienode_test.go new file mode 100644 index 0000000000..9b22e9e6e2 --- /dev/null +++ b/core/trie2/trienode/trienode_test.go @@ -0,0 +1,188 @@ +package trienode + +import ( + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/core/trie2/trieutils" + "github.com/stretchr/testify/require" +) + +func TestNodeSet(t *testing.T) { + t.Run("new node set", func(t *testing.T) { + ns := NewNodeSet(felt.Zero) + require.Equal(t, felt.Zero, ns.Owner) + require.Empty(t, ns.Nodes) + require.Zero(t, ns.updates) + require.Zero(t, ns.deletes) + }) + + t.Run("add nodes", func(t *testing.T) { + ns := NewNodeSet(felt.Zero) + + // Add a regular node + key1 := trieutils.NewBitArray(8, 0xFF) + node1 := NewNode(felt.Zero, []byte{1, 2, 3}) + ns.Add(key1, node1) + require.Equal(t, 1, ns.updates) + require.Equal(t, 0, ns.deletes) + + // Add a deleted node + key2 := trieutils.NewBitArray(8, 0xAA) + node2 := NewDeleted() + ns.Add(key2, node2) + require.Equal(t, 1, ns.updates) + require.Equal(t, 1, ns.deletes) + + // Verify nodes are stored correctly + require.Equal(t, node1, ns.Nodes[key1]) + require.Equal(t, node2, ns.Nodes[key2]) + }) + + t.Run("merge sets", func(t *testing.T) { + ns1 := NewNodeSet(felt.Zero) + ns2 := NewNodeSet(felt.Zero) + + // Add nodes to first set + key1 := trieutils.NewBitArray(8, 0xFF) + node1 := NewNode(felt.Zero, []byte{1, 2, 3}) + ns1.Add(key1, node1) + + // Add nodes to second set + key2 := trieutils.NewBitArray(8, 0xAA) + node2 := NewDeleted() + ns2.Add(key2, node2) + + // Merge sets + err := ns1.MergeSet(ns2) + require.NoError(t, err) + + // Verify merged state + require.Equal(t, 2, len(ns1.Nodes)) + require.Equal(t, node1, ns1.Nodes[key1]) + require.Equal(t, node2, ns1.Nodes[key2]) + require.Equal(t, 1, ns1.updates) + require.Equal(t, 1, ns1.deletes) + }) + + t.Run("merge with different owners", func(t *testing.T) { + owner1 := new(felt.Felt).SetUint64(123) + owner2 := new(felt.Felt).SetUint64(456) + ns1 := NewNodeSet(*owner1) + ns2 := NewNodeSet(*owner2) + + err := ns1.MergeSet(ns2) + require.Error(t, err) + }) + + t.Run("merge map", func(t *testing.T) { + owner := new(felt.Felt).SetUint64(123) + ns := NewNodeSet(*owner) + + // Create a map to merge + nodes := make(map[trieutils.BitArray]*Node) + key1 := trieutils.NewBitArray(8, 0xFF) + node1 := NewNode(felt.Zero, []byte{1, 2, 3}) + nodes[key1] = node1 + + // Merge map + err := ns.Merge(*owner, nodes) + require.NoError(t, err) + + // Verify merged state + require.Equal(t, 1, len(ns.Nodes)) + require.Equal(t, node1, ns.Nodes[key1]) + require.Equal(t, 1, ns.updates) + require.Equal(t, 0, ns.deletes) + }) + + t.Run("foreach", func(t *testing.T) { + ns := NewNodeSet(felt.Zero) + + // Add nodes in random order + keys := []trieutils.BitArray{ + trieutils.NewBitArray(8, 0xFF), + trieutils.NewBitArray(8, 0xAA), + trieutils.NewBitArray(8, 0x55), + } + for _, key := range keys { + ns.Add(key, NewNode(felt.Zero, []byte{1})) + } + + t.Run("ascending order", func(t *testing.T) { + var visited []trieutils.BitArray + _ = ns.ForEach(false, func(key trieutils.BitArray, node *Node) error { + visited = append(visited, key) + return nil + }) + + // Verify ascending order + for i := 1; i < len(visited); i++ { + require.True(t, visited[i-1].Cmp(&visited[i]) < 0) + } + }) + + t.Run("descending order", func(t *testing.T) { + var visited []trieutils.BitArray + _ = ns.ForEach(true, func(key trieutils.BitArray, node *Node) error { + visited = append(visited, key) + return nil + }) + + // Verify descending order + for i := 1; i < len(visited); i++ { + require.True(t, visited[i-1].Cmp(&visited[i]) > 0) + } + }) + }) +} + +func TestNode(t *testing.T) { + t.Run("new node", func(t *testing.T) { + hash := new(felt.Felt).SetUint64(123) + blob := []byte{1, 2, 3} + node := NewNode(*hash, blob) + + require.Equal(t, *hash, node.hash) + require.Equal(t, blob, node.blob) + require.False(t, node.IsDeleted()) + }) + + t.Run("new deleted node", func(t *testing.T) { + node := NewDeleted() + require.True(t, node.IsDeleted()) + require.Equal(t, felt.Zero, node.hash) + require.Nil(t, node.blob) + }) + + t.Run("is deleted", func(t *testing.T) { + tests := []struct { + name string + blob []byte + expected bool + }{ + { + name: "nil blob", + blob: nil, + expected: true, + }, + { + name: "empty blob", + blob: []byte{}, + expected: true, + }, + { + name: "non-empty blob", + blob: []byte{1, 2, 3}, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + node := NewNode(felt.Zero, test.blob) + require.Equal(t, test.expected, node.IsDeleted()) + }) + } + }) +} diff --git a/core/trie2/trieutils/bitarray.go b/core/trie2/trieutils/bitarray.go new file mode 100644 index 0000000000..6d53057225 --- /dev/null +++ b/core/trie2/trieutils/bitarray.go @@ -0,0 +1,903 @@ +package trieutils + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "math" + "math/bits" + + "github.com/NethermindEth/juno/core/felt" +) + +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) + bytes32 = 32 + MaxBitArraySize = 33 // (1 + 4 * 8) bytes +) + +var emptyBitArray = new(BitArray) // TODO(weiihann): remove this nolint when we replace legacy trie + +// Represents a bit array with length representing the number of used bits. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The max length is 255 bits (uint8), because our use case only need up to 251 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. +type BitArray struct { + len uint8 // number of used bits + words [4]uint64 // little endian (i.e. words[0] is the least significant) +} + +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b +} + +// Returns the felt representation of the bit array. +func (b *BitArray) Felt() felt.Felt { + var f felt.Felt + bt := b.Bytes() + f.SetBytes(bt[:]) + return f +} + +func (b *BitArray) Len() uint8 { + return b.len +} + +// Returns the bytes representation of the bit array in big endian format +func (b *BitArray) Bytes() [32]byte { + var res [32]byte + + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + + return res +} + +// Sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. +// If length >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// LSBsFromLSB(x, 4) = 1011 (len=4) +// LSBsFromLSB(x, 10) = 11001011 (len=8, original x) +// LSBsFromLSB(x, 0) = 0 (len=0) +// +//nolint:mnd +func (b *BitArray) LSBsFromLSB(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + b.len = n + + switch { + case n == 0: + b.words = [4]uint64{0, 0, 0, 0} + case n <= 64: + b.words[0] = x.words[0] & (maxUint64 >> (64 - n)) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case n <= 128: + b.words[0] = x.words[0] + b.words[1] = x.words[1] & (maxUint64 >> (128 - n)) + b.words[2], b.words[3] = 0, 0 + case n <= 192: + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] & (maxUint64 >> (192 - n)) + b.words[3] = 0 + default: + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] & (maxUint64 >> (256 - uint16(n))) + } + + return b +} + +// Returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// Think of this method as array[n:] +// For example: +// +// x = 11001011 (len=8) +// LSBs(x, 1) = 1001011 (len=7) +// LSBs(x, 10) = 0 (len=0) +// LSBs(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) LSBs(x *BitArray, n uint8) *BitArray { + if n == 0 { + return b.Set(x) + } + + if n > x.Len() { + return b.clear() + } + + return b.LSBsFromLSB(x, x.Len()-n) +} + +// Checks if the current bit array share the same most significant bits with another, where the length of +// the check is determined by the shorter array. Returns true if either array has +// length 0, or if the first min(b.len, x.len) MSBs are identical. +// +// For example: +// +// a = 1101 (len=4) +// b = 11010111 (len=8) +// a.EqualMSBs(b) = true // First 4 MSBs match +// +// a = 1100 (len=4) +// b = 1101 (len=4) +// a.EqualMSBs(b) = false // All bits compared, not equal +// +// a = 1100 (len=4) +// b = [] (len=0) +// a.EqualMSBs(b) = true // Zero length is always a prefix match +func (b *BitArray) EqualMSBs(x *BitArray) bool { + if b.len == x.len { + return b.Equal(x) + } + + if b.len == 0 || x.len == 0 { + return true + } + + // Compare only the first min(b.len, x.len) bits + minLen := b.len + if x.len < minLen { + minLen = x.len + } + + return new(BitArray).MSBs(b, minLen).Equal(new(BitArray).MSBs(x, minLen)) +} + +// Sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). +// If n >= x.len, the bit array is an exact copy of x. +// Think of this method as array[0:n] +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.Rsh(x, x.len-n) +} + +// Sets the bit array to the longest sequence of matching most significant bits between two bit arrays. +// For example: +// +// x = 1101 0111 (len=8) +// y = 1101 0000 (len=8) +// CommonMSBs(x,y) = 1101 (len=4) +func (b *BitArray) CommonMSBs(x, y *BitArray) *BitArray { + if x.len == 0 || y.len == 0 { + return b.clear() + } + + long, short := x, y + if x.len < y.len { + long, short = y, x + } + + // Align arrays by right-shifting longer array and then XOR to find differences + // Example: + // short = 1100 (len=4) + // long = 1101 0111 (len=8) + // + // Step 1: Right shift longer array by 4 + // short = 1100 + // long = 1101 + // + // Step 2: XOR shows difference at last bit + // 1100 (short) + // 1101 (aligned long) + // ---- XOR + // 0001 (difference at last position) + // We can then use the position of the first set bit and right-shift to get the common MSBs + diff := long.len - short.len + b.Rsh(long, diff).Xor(b, short) + divergentBit := findFirstSetBit(b) + + return b.Rsh(short, divergentBit) +} + +// Sets the bit array to x >> n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Rsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 { + return b.Set(x) + } + + if n >= x.len { + return b.clear() + } + + switch { + case n == 0: + return b.Set(x) + case n >= 192: + b.rsh192(x) + b.len = x.len - n + n -= 192 + b.words[0] >>= n + case n >= 128: + b.rsh128(x) + b.len = x.len - n + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + case n >= 64: + b.rsh64(x) + b.len = x.len - n + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] >>= n + default: + b.Set(x) + b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + } + + b.truncateToLength() + return b +} + +// Lsh sets the bit array to x << n and returns the bit array. +// +//nolint:mnd +func (b *BitArray) Lsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 || n == 0 { + return b.Set(x) + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[3], x.words[2], x.words[1], x.words[0] + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +// Sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 || y.len == maxUint8 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + + // Then shift left by y's length and OR with y + return b.Lsh(x, y.len).Or(b, y) +} + +// Sets the bit array to the concatenation of x and a single bit. +func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { + return b.Append(x, new(BitArray).SetBit(bit)) +} + +// Sets the bit array to the concatenation of x and n zeros. +func (b *BitArray) AppendZeros(x *BitArray, n uint8) *BitArray { + return b.Append(x, new(BitArray).Zeros(n)) +} + +// Sets the bit array to a subset of x from startPos (inclusive) to endPos (exclusive), +// where position 0 is the MSB. If startPos >= endPos or if startPos >= x.len, +// returns an empty BitArray. +// Think of this method as array[start:end] +// For example: +// +// x = 001011011 (len=9) +// Subset(x, 2, 5) = 101 (len=3) +func (b *BitArray) Subset(x *BitArray, startPos, endPos uint8) *BitArray { + // Check for invalid inputs + if startPos >= endPos || startPos >= x.len { + return b.clear() + } + + // Clamp endPos to x.len if it exceeds it + if endPos > x.len { + endPos = x.len + } + + length := endPos - startPos + + // First, trim off the MSBs that are not part of the subset + b.LSBs(x, startPos) + + // Then, we create a mask of ones and appends zeros to the end to match the length + mask := new(BitArray).Ones(length) + zeros := &BitArray{len: b.len - length} + mask.Append(mask, zeros) + + // Apply the mask to the bit array and then only take the first `length` bits + return b.And(b, mask).MSBs(b, length) +} + +// Sets the bit array to x | y and returns the bit array. +func (b *BitArray) Or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len + return b +} + +func (b *BitArray) And(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] & y.words[0] + b.words[1] = x.words[1] & y.words[1] + b.words[2] = x.words[2] & y.words[2] + b.words[3] = x.words[3] & y.words[3] + b.len = x.len + return b +} + +// Sets the bit array to x ^ y and returns the bit array. +func (b *BitArray) Xor(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] ^ y.words[0] + b.words[1] = x.words[1] ^ y.words[1] + b.words[2] = x.words[2] ^ y.words[2] + b.words[3] = x.words[3] ^ y.words[3] + return b +} + +// Checks if two bit arrays are equal +func (b *BitArray) Equal(x *BitArray) bool { + // TODO(weiihann): this is really not a good thing to do... + if b == nil && x == nil { + return true + } else if b == nil || x == nil { + return false + } + + return b.len == x.len && + b.words[0] == x.words[0] && + b.words[1] == x.words[1] && + b.words[2] == x.words[2] && + b.words[3] == x.words[3] +} + +// Returns true if bit n-th is set, where n = 0 is LSB. +func (b *BitArray) IsBitSetFromLSB(n uint8) bool { + return b.BitFromLSB(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) BitFromLSB(n uint8) uint8 { + if n >= b.len { + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 +} + +func (b *BitArray) IsBitSet(n uint8) bool { + return b.Bit(n) == 1 +} + +// Returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) Bit(n uint8) uint8 { + if n >= b.Len() { + return 0 + } + + return b.BitFromLSB(b.Len() - n - 1) +} + +// Returns the bit value at the most significant bit +func (b *BitArray) MSB() uint8 { + return b.Bit(0) +} + +func (b *BitArray) LSB() uint8 { + return b.BitFromLSB(0) +} + +func (b *BitArray) IsEmpty() bool { + return b.len == 0 +} + +// Serialises the BitArray into a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order, without leading zeros +// Example: +// +// BitArray{len: 10, words: [4]uint64{0x03FF}} -> [0x0A, 0x03, 0xFF] +func (b *BitArray) Write(buf *bytes.Buffer) (int, error) { + if err := buf.WriteByte(b.len); err != nil { + return 0, err + } + + n, err := buf.Write(b.activeBytes()) + return n + 1, err +} + +// Deserialises the BitArray from a bytes buffer in the following format: +// - First byte: length of the bit array (0-255) +// - Remaining bytes: the necessary bytes included in big endian order +// Example: +// +// [0x0A, 0x03, 0xFF] -> BitArray{len: 10, words: [4]uint64{0x03FF}} +func (b *BitArray) UnmarshalBinary(data []byte) error { + if len(data) == 0 { + return errors.New("empty data") + } + + length := data[0] + byteCount := (int(length) + 7) / 8 // Get the total number of bytes needed to represent the bit array + + if len(data) > byteCount+1 { + return fmt.Errorf("invalid data length: got %d bytes, expected <= %d", len(data), byteCount+1) + } + + b.len = length + + var bs [32]byte + bitArrBytes := data[1:] + copy(bs[32-len(bitArrBytes):], bitArrBytes) // Fill up the non-zero bytes at the end of the byte array + b.setBytes32(bs[:]) + + return nil +} + +// Sets the bit array to the same value as x. +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +// Sets the bit array to the bytes representation of a felt. +func (b *BitArray) SetFelt(length uint8, f *felt.Felt) *BitArray { + b.len = length + b.setFelt(f) + b.truncateToLength() + return b +} + +// Sets the bit array to the bytes representation of a felt with length 251. +func (b *BitArray) SetFelt251(f *felt.Felt) *BitArray { + b.len = 251 + b.setFelt(f) + b.truncateToLength() + return b +} + +// Interprets the data as the big-endian bytes, sets the bit array to that value and returns it. +// If the data is larger than 32 bytes, only the first 32 bytes are used. +// +//nolint:mnd,funlen,gocyclo +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + switch l := len(data); l { + case 0: + b.clear() + case 1: + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(data[0]) + case 2: + _ = data[1] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])) + case 3: + _ = data[2] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + case 4: + _ = data[3] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])) + case 5: + _ = data[4] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint40(data[0:5]) + case 6: + _ = data[5] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint48(data[0:6]) + case 7: + _ = data[6] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint56(data[0:7]) + case 8: + _ = data[7] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, binary.BigEndian.Uint64(data[0:8]) + case 9: + _ = data[8] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + case 10: + _ = data[9] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + case 11: + _ = data[10] + b.words[3], b.words[2] = 0, 0 + b.words[1], b.words[0] = uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16, binary.BigEndian.Uint64(data[3:11]) + case 12: + _ = data[11] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + case 13: + _ = data[12] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + case 14: + _ = data[13] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + case 15: + _ = data[14] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + case 16: + _ = data[15] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) + case 17: + _ = data[16] + b.words[3], b.words[2] = 0, uint64(data[0]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[1:9]), binary.BigEndian.Uint64(data[9:17]) + case 18: + _ = data[17] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[0:2])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[2:10]), binary.BigEndian.Uint64(data[10:18]) + case 19: + _ = data[18] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[3:11]), binary.BigEndian.Uint64(data[11:19]) + case 20: + _ = data[19] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint32(data[0:4])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[4:12]), binary.BigEndian.Uint64(data[12:20]) + case 21: + _ = data[20] + b.words[3], b.words[2] = 0, bigEndianUint40(data[0:5]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[5:13]), binary.BigEndian.Uint64(data[13:21]) + case 22: + _ = data[21] + b.words[3], b.words[2] = 0, bigEndianUint48(data[0:6]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[6:14]), binary.BigEndian.Uint64(data[14:22]) + case 23: + _ = data[22] + b.words[3], b.words[2] = 0, bigEndianUint56(data[0:7]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[7:15]), binary.BigEndian.Uint64(data[15:23]) + case 24: + _ = data[23] + b.words[3], b.words[2] = 0, binary.BigEndian.Uint64(data[0:8]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[8:16]), binary.BigEndian.Uint64(data[16:24]) + case 25: + _ = data[24] + b.words[3], b.words[2] = uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[9:17]), binary.BigEndian.Uint64(data[17:25]) + case 26: + _ = data[25] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[10:18]), binary.BigEndian.Uint64(data[18:26]) + case 27: + _ = data[26] + b.words[3] = uint64(binary.BigEndian.Uint16(data[1:3])) | uint64(data[0])<<16 + b.words[2] = binary.BigEndian.Uint64(data[3:11]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[11:19]), binary.BigEndian.Uint64(data[19:27]) + case 28: + _ = data[27] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[12:20]), binary.BigEndian.Uint64(data[20:28]) + case 29: + _ = data[28] + b.words[3], b.words[2] = bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[13:21]), binary.BigEndian.Uint64(data[21:29]) + case 30: + _ = data[29] + b.words[3], b.words[2] = bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[14:22]), binary.BigEndian.Uint64(data[22:30]) + case 31: + _ = data[30] + b.words[3], b.words[2] = bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[15:23]), binary.BigEndian.Uint64(data[23:31]) + default: + b.setBytes32(data) + } + b.len = length + b.truncateToLength() + return b +} + +// Sets the bit array to the uint64 representation of a bit array. +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { + b.words[0] = data + b.len = length + b.truncateToLength() + return b +} + +// Sets the bit array to a single bit. +func (b *BitArray) SetBit(bit uint8) *BitArray { + b.len = 1 + b.words[0] = uint64(bit & 1) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + return b +} + +// Returns the length of the encoded bit array in bytes. +func (b *BitArray) EncodedLen() uint { + return b.byteCount() + 1 +} + +// Returns a deep copy of the bit array. +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +// Returns the encoded string representation of the bit array. +func (b *BitArray) EncodedString() string { + var res []byte + bt := b.Bytes() + res = append(res, b.len) + res = append(res, bt[:]...) + return string(res) +} + +// Returns a string representation of the bit array. +// This is typically used for logging or debugging. +func (b *BitArray) String() string { + bt := b.Bytes() + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bt[:])) +} + +func (b *BitArray) setFelt(f *felt.Felt) { + res := f.Bytes() + b.words[3] = binary.BigEndian.Uint64(res[0:8]) + b.words[2] = binary.BigEndian.Uint64(res[8:16]) + b.words[1] = binary.BigEndian.Uint64(res[16:24]) + b.words[0] = binary.BigEndian.Uint64(res[24:32]) +} + +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] // bound check hint, see https://golang.org/issue/14808 + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) +} + +// Returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *BitArray) byteCount() uint { + const bits8 = 8 + return (uint(b.len) + (bits8 - 1)) / uint(bits8) +} + +// Returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. Leading zero bytes will be removed. +// The returned slice is in big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *BitArray) activeBytes() []byte { + if b.len == 0 { + return nil + } + + wordsBytes := b.Bytes() + end := uint(bytes32) + start := end - b.byteCount() + + // Find first non-zero byte + for start < end-1 && wordsBytes[start] == 0 { + start++ + } + + if start == end { + return nil + } + + return wordsBytes[start:end] +} + +func (b *BitArray) rsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *BitArray) rsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *BitArray) rsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + +func (b *BitArray) clear() *BitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} + +// Truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +// Example: +// +// b := &BitArray{ +// len: 5, +// words: [4]uint64{ +// 0xFFFFFFFFFFFFFFFF, // Before: all bits are 1 +// 0x0, 0x0, 0x0, +// }, +// } +// b.truncateToLength() +// // After: only first 5 bits remain +// // words[0] = 0x000000000000001F +// // words[1..3] = 0x0 +// +//nolint:mnd +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + +// Sets the bit array to a sequence of ones of the specified length. +func (b *BitArray) Ones(length uint8) *BitArray { + b.len = length + b.words[0] = maxUint64 + b.words[1] = maxUint64 + b.words[2] = maxUint64 + b.words[3] = maxUint64 + b.truncateToLength() + return b +} + +func (b *BitArray) Zeros(length uint8) *BitArray { + b.len = length + b.words[0] = 0 + b.words[1] = 0 + b.words[2] = 0 + b.words[3] = 0 + return b +} + +// Returns the position of the first '1' bit in the array, scanning from most significant to least significant bit. +// The bit position is counted from the least significant bit, starting at 0. +// For example: +// +// array = 0000 0000 ... 0100 (len=251) +// findFirstSetBit() = 3 // third bit from right is set +func findFirstSetBit(b *BitArray) uint8 { + if b.len == 0 { + return 0 + } + + // Start from the most significant and move towards the least significant + for i := 3; i >= 0; i-- { + if word := b.words[i]; word != 0 { + return uint8((i+1)*64 - bits.LeadingZeros64(word)) + } + } + + // All bits are zero, no set bit found + return 0 +} + +// Cmp compares two bit arrays lexicographically. +// The comparison is first done by length, then by content if lengths are equal. +// Returns: +// +// -1 if b < x +// 0 if b == x +// 1 if b > x +func (b *BitArray) Cmp(x *BitArray) int { + // First compare lengths + if b.len < x.len { + return -1 + } + if b.len > x.len { + return 1 + } + + // Lengths are equal, compare the actual bits + d0, carry := bits.Sub64(b.words[0], x.words[0], 0) + d1, carry := bits.Sub64(b.words[1], x.words[1], carry) + d2, carry := bits.Sub64(b.words[2], x.words[2], carry) + d3, carry := bits.Sub64(b.words[3], x.words[3], carry) + + if carry == 1 { + return -1 + } + + if d0|d1|d2|d3 == 0 { + return 0 + } + + return 1 +} + +func bigEndianUint40(b []byte) uint64 { + _ = b[4] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[4]) | uint64(b[3])<<8 | uint64(b[2])<<16 | uint64(b[1])<<24 | + uint64(b[0])<<32 +} + +func bigEndianUint48(b []byte) uint64 { + _ = b[5] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[5]) | uint64(b[4])<<8 | uint64(b[3])<<16 | uint64(b[2])<<24 | + uint64(b[1])<<32 | uint64(b[0])<<40 +} + +func bigEndianUint56(b []byte) uint64 { + _ = b[6] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[6]) | uint64(b[5])<<8 | uint64(b[4])<<16 | uint64(b[3])<<24 | + uint64(b[2])<<32 | uint64(b[1])<<40 | uint64(b[0])<<48 +} diff --git a/core/trie2/trieutils/bitarray_test.go b/core/trie2/trieutils/bitarray_test.go new file mode 100644 index 0000000000..1835b538bf --- /dev/null +++ b/core/trie2/trieutils/bitarray_test.go @@ -0,0 +1,2126 @@ +package trieutils + +import ( + "bytes" + "encoding/binary" + "math/bits" + "testing" + + "github.com/NethermindEth/juno/core/felt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 +) + +func TestBytes(t *testing.T) { + tests := []struct { + name string + ba BitArray + want [32]byte + }{ + { + name: "length == 0", + ba: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: [32]byte{}, + }, + { + name: "length < 64", + ba: BitArray{len: 38, words: [4]uint64{0x3FFFFFFFFF, 0, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + ba: BitArray{len: 100, words: [4]uint64{maxUint64, 0xFFFFFFFFF, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "128 <= length < 192", + ba: BitArray{len: 130, words: [4]uint64{maxUint64, maxUint64, 0x3, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x3) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "192 <= length < 255", + ba: BitArray{len: 201, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x1FF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 254", + ba: BitArray{len: 254, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x3FFFFFFFFFFFFFFF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 255", + ba: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], ones63) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.Bytes() + if !bytes.Equal(got[:], tt.want[:]) { + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) + } + + // check if the received bytes has the same bit count as the BitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = %v, want %v", count, tt.ba.len) + } + }) + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *BitArray + shiftBy uint8 + expected *BitArray + }{ + { + name: "zero length array", + initial: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + shiftBy: 0, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 65, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 32, + expected: &BitArray{ + len: 96, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 64, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by 127", + initial: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + shiftBy: 127, + expected: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 128, + expected: &BitArray{ + len: 123, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 192, + expected: &BitArray{ + len: 59, + words: [4]uint64{0x7FFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).Rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("Rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 5, + want: emptyBitArray, + }, + { + name: "shift by 0", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 4, + want: &BitArray{ + len: 8, + words: [4]uint64{0xF0, 0, 0, 0}, // 11110000 + }, + }, + { + name: "shift across word boundary", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 62, + want: &BitArray{ + len: 66, + words: [4]uint64{0xC000000000000000, 0x3, 0, 0}, + }, + }, + { + name: "shift by 64 (full word)", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 64, + want: &BitArray{ + len: 72, + words: [4]uint64{0, 0xFF, 0, 0}, + }, + }, + { + name: "shift by 128", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 128, + want: &BitArray{ + len: 136, + words: [4]uint64{0, 0, 0xFF, 0}, + }, + }, + { + name: "shift by 192", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 192, + want: &BitArray{ + len: 200, + words: [4]uint64{0, 0, 0, 0xFF}, + }, + }, + { + name: "shift causing length overflow", + x: &BitArray{ + len: 200, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{ + 0xF000000000000000, + 0xF, + 0, + 0, + }, + }, + }, + { + name: "shift sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + n: 4, + want: &BitArray{ + len: 12, + words: [4]uint64{0xAA0, 0, 0, 0}, // 101010100000 + }, + }, + { + name: "shift partial word across boundary", + x: &BitArray{ + len: 100, + words: [4]uint64{0xFF, 0xFF, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 160, + words: [4]uint64{ + 0xF000000000000000, + 0xF00000000000000F, + 0xF, + 0, + }, + }, + }, + { + name: "near maximum length shift", + x: &BitArray{ + len: 251, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 4, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{0xFF0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Lsh(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("Lsh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "both empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "first array empty", + x: emptyBitArray, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "second array empty", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: emptyBitArray, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "different lengths within word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 2, + words: [4]uint64{0x3, 0, 0, 0}, // 11 + }, + want: &BitArray{ + len: 6, + words: [4]uint64{0x3F, 0, 0, 0}, // 111111 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 62, + words: [4]uint64{0x3FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 66, + words: [4]uint64{maxUint64, 0x3, 0, 0}, + }, + }, + { + name: "across multiple words", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + y: &BitArray{ + len: 8, + words: [4]uint64{0x55, 0, 0, 0}, // 01010101 + }, + want: &BitArray{ + len: 16, + words: [4]uint64{0xAA55, 0, 0, 0}, // 1010101001010101 + }, + }, + { + name: "result exactly at length limit", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + want: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Append(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("Append() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEqualMSBs(t *testing.T) { + tests := []struct { + name string + a *BitArray + b *BitArray + want bool + }{ + { + name: "equal lengths, equal values", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: true, + }, + { + name: "equal lengths, different values", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different lengths, a longer but same prefix", + a: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, b longer but same prefix", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + want: true, + }, + { + name: "different lengths, different prefix", + a: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + b: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFFFFFFFFF0, 0, 0, 0}, + }, + want: false, + }, + { + name: "zero length arrays", + a: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + b: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "one zero length array", + a: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + b: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: true, + }, + { + name: "max length difference", + a: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + b: &BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.a.EqualMSBs(tt.b); got != tt.want { + t.Errorf("PrefixEqual() = %v, want %v", got, tt.want) + } + // Test symmetry: a.PrefixEqual(b) should equal b.PrefixEqual(a) + if got := tt.b.EqualMSBs(tt.a); got != tt.want { + t.Errorf("PrefixEqual() symmetric test = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).LSBs(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBsFromLSB(t *testing.T) { + tests := []struct { + name string + initial BitArray + length uint8 + expected BitArray + }{ + { + name: "zero", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 0, + expected: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get 32 LSBs", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 32, + expected: BitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get 1 LSB", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 1, + expected: BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + { + name: "get 100 LSBs across words", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "get 64 LSBs at word boundary", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get 128 LSBs at word boundary", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 128, + expected: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "get 150 LSBs in third word", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 150, + expected: BitArray{ + len: 150, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, + }, + }, + { + name: "get 220 LSBs in fourth word", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 220, + expected: BitArray{ + len: 220, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, + }, + }, + { + name: "get 251 LSBs", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 251, + expected: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "get 100 LSBs from sparse bits", + initial: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 128, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).LSBsFromLSB(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestMSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 0, + want: emptyBitArray, + }, + { + name: "get all bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get more bits than available", + x: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get half of available bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 32, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0}, + }, + }, + { + name: "get MSBs across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + n: 100, + want: &BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0}, + }, + }, + { + name: "get MSBs from max length array", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get zero bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{0x5555555555555555, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).MSBs(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + + if got.len != tt.want.len { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteAndUnmarshalBinary(t *testing.T) { + tests := []struct { + name string + ba BitArray + want []byte // Expected bytes after writing + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + want: []byte{0}, // Just the length byte + }, + { + name: "8 bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + want: []byte{8, 0xFF}, // length byte + 1 data byte + }, + { + name: "10 bits requiring 2 bytes", + ba: BitArray{ + len: 10, + words: [4]uint64{0x3FF, 0, 0, 0}, // 1111111111 in binary + }, + want: []byte{10, 0x3, 0xFF}, // length byte + 2 data bytes + }, + { + name: "64 bits", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: append( + []byte{64}, // length byte + []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}..., // 8 data bytes + ), + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + want: func() []byte { + b := make([]byte, 33) // 1 length byte + 32 data bytes + b[0] = 251 // length byte + // First byte is 0x07 (from the most significant bits) + b[1] = 0x07 + // Rest of the bytes are 0xFF + for i := 2; i < 33; i++ { + b[i] = 0xFF + } + return b + }(), + }, + { + name: "sparse bits", + ba: BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 in binary + }, + want: []byte{16, 0xAA, 0xAA}, // length byte + 2 data bytes + }, + { + name: "leading zeros in first byte", + ba: BitArray{ + len: 8, + words: [4]uint64{0x0F, 0, 0, 0}, // 00001111 + }, + want: []byte{8, 0x0F}, // length byte + 1 data byte + }, + { + name: "all leading zeros in first byte", + ba: BitArray{ + len: 8, + words: [4]uint64{0x00, 0, 0, 0}, // 00000000 + }, + want: []byte{8, 0x00}, // length byte + 1 data byte + }, + { + name: "leading zeros across multiple bytes", + ba: BitArray{ + len: 24, + words: [4]uint64{0x0000FF, 0, 0, 0}, // 000000000000000011111111 + }, + want: []byte{24, 0xFF}, // length byte + 1 data byte + }, + { + name: "leading zeros in large number", + ba: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, // All 1s in lower bits, zeros in upper bits + }, + want: append( + []byte{255}, // length byte + bytes.Repeat([]byte{0xFF}, 16)..., // 16 bytes of all 1s + ), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := new(bytes.Buffer) + gotN, err := tt.ba.Write(buf) + assert.NoError(t, err) + + // Check number of bytes written + if gotN != len(tt.want) { + t.Errorf("Write() wrote %d bytes, want %d", gotN, len(tt.want)) + } + + // Check written bytes + got := buf.Bytes() + if !bytes.Equal(got, tt.want) { + t.Errorf("Write() = %v, want %v", got, tt.want) + } + + var gotBitArray BitArray + err = gotBitArray.UnmarshalBinary(got) + require.NoError(t, err) + if !gotBitArray.Equal(&tt.ba) { + t.Errorf("UnmarshalBinary() = %v, want %v", gotBitArray, tt.ba) + } + }) + } +} + +func TestCommonPrefix(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "one empty array", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "identical arrays - single word", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "identical arrays - multiple words", + x: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + y: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "different lengths with common prefix - first word", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + y: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different lengths with common prefix - multiple words", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + y: &BitArray{ + len: 127, + words: [4]uint64{maxUint64, ones63, 0, 0}, + }, + want: &BitArray{ + len: 127, + words: [4]uint64{maxUint64, ones63, 0, 0}, + }, + }, + { + name: "different at first bit", + x: &BitArray{ + len: 64, + words: [4]uint64{ones63, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "different in middle of first word", + x: &BitArray{ + len: 64, + words: [4]uint64{0xFFFFFFFF0FFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in second word", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, 0xFFFFFFFF0FFFFFFF, 0, 0}, + }, + y: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in third word", + x: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + y: &BitArray{ + len: 192, + words: [4]uint64{0, 0, 0xFFFFFFFFFFFFFF0F, 0}, + }, + want: &BitArray{ + len: 56, + words: [4]uint64{0xFFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "different in last word", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFF0FFFFFFF}, + }, + want: &BitArray{ + len: 27, + words: [4]uint64{0x7FFFFFF}, + }, + }, + { + name: "sparse bits with common prefix", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}, + }, + y: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}, + }, + want: &BitArray{ + len: 52, + words: [4]uint64{0xAAAAAAAAAAAAA, 0, 0, 0}, + }, + }, + { + name: "max length difference", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + y: &BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + want: &BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray) + gotSymmetric := new(BitArray) + + got.CommonMSBs(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("CommonMSBs() = %v, want %v", got, tt.want) + } + + // Test symmetry: x.CommonMSBs(y) should equal y.CommonMSBs(x) + gotSymmetric.CommonMSBs(tt.y, tt.x) + if !gotSymmetric.Equal(tt.want) { + t.Errorf("CommonMSBs() symmetric test = %v, want %v", gotSymmetric, tt.want) + } + }) + } +} + +func TestIsBitSetFromLSB(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit set", + ba: BitArray{ + len: 64, + words: [4]uint64{1, 0, 0, 0}, + }, + pos: 0, + want: true, + }, + { + name: "last bit in first word", + ba: BitArray{ + len: 64, + words: [4]uint64{1 << 63, 0, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "first bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 64, + want: true, + }, + { + name: "bit beyond length", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 65, + want: false, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 1, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 in binary + }, + pos: 0, + want: false, + }, + { + name: "bit in last word", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 59}, + }, + pos: 251, + want: false, // position 251 is beyond the highest valid bit (250) + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 250, + want: true, + }, + { + name: "highest valid bit (255)", + ba: BitArray{ + len: 255, + words: [4]uint64{0, 0, 0, 1 << 62}, // bit 255 set + }, + pos: 254, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 100, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSetFromLSB(tt.pos) + if got != tt.want { + t.Errorf("IsBitSetFromLSB(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + +func TestIsBitSet(t *testing.T) { + tests := []struct { + name string + ba BitArray + pos uint8 + want bool + }{ + { + name: "empty array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + pos: 0, + want: false, + }, + { + name: "first bit (MSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x80, 0, 0, 0}, // 10000000 + }, + pos: 0, + want: true, + }, + { + name: "last bit (LSB) set", + ba: BitArray{ + len: 8, + words: [4]uint64{0x01, 0, 0, 0}, // 00000001 + }, + pos: 7, + want: true, + }, + { + name: "alternating bits", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 0, + want: true, + }, + { + name: "alternating bits - unset position", + ba: BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + pos: 1, + want: false, + }, + { + name: "position beyond length", + ba: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + pos: 8, + want: false, + }, + { + name: "bit in second word", + ba: BitArray{ + len: 128, + words: [4]uint64{0, 1, 0, 0}, + }, + pos: 63, + want: true, + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{0, 0, 0, 1 << 58}, + }, + pos: 0, + want: true, + }, + { + name: "position at length boundary", + ba: BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 99, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.IsBitSet(tt.pos) + if got != tt.want { + t.Errorf("IsBitSet(%d) = %v, want %v", tt.pos, got, tt.want) + } + }) + } +} + +func TestFeltConversion(t *testing.T) { + tests := []struct { + name string + ba BitArray + length uint8 + want string // hex representation of felt + }{ + { + name: "empty bit array", + ba: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + length: 0, + want: "0x0", + }, + { + name: "single word", + ba: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + want: "0xffffffffffffffff", + }, + { + name: "two words", + ba: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 128, + want: "0xffffffffffffffffffffffffffffffff", + }, + { + name: "three words", + ba: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 192, + want: "0xffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "251 bits", + ba: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + length: 251, + want: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + }, + { + name: "sparse bits", + ba: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 128, + want: "0x5555555555555555aaaaaaaaaaaaaaaa", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test Felt() conversion + gotFelt := tt.ba.Felt() + assert.Equal(t, tt.want, gotFelt.String()) + + // Test SetFelt() conversion (round trip) + var newBA BitArray + newBA.SetFelt(tt.length, &gotFelt) + assert.Equal(t, tt.ba.len, newBA.len) + assert.Equal(t, tt.ba.words, newBA.words) + }) + } +} + +func TestSetFeltValidation(t *testing.T) { + tests := []struct { + name string + feltStr string + length uint8 + shouldMatch bool + }{ + { + name: "valid felt with matching length", + feltStr: "0xf", + length: 4, + shouldMatch: true, + }, + { + name: "felt larger than specified length", + feltStr: "0xff", + length: 4, + shouldMatch: false, + }, + { + name: "zero felt with non-zero length", + feltStr: "0x0", + length: 8, + shouldMatch: true, + }, + { + name: "max felt with max length", + feltStr: "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + length: 251, + shouldMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var f felt.Felt + _, err := f.SetString(tt.feltStr) + require.NoError(t, err) + + var ba BitArray + ba.SetFelt(tt.length, &f) + + // Convert back to felt and compare + roundTrip := ba.Felt() + if tt.shouldMatch { + assert.True(t, roundTrip.Equal(&f), + "expected %s, got %s", f.String(), roundTrip.String()) + } else { + assert.False(t, roundTrip.Equal(&f), + "values should not match: original %s, roundtrip %s", + f.String(), roundTrip.String()) + } + }) + } +} + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + bit uint8 + want BitArray + }{ + { + name: "set bit 0", + bit: 0, + want: BitArray{ + len: 1, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "set bit 1", + bit: 1, + want: BitArray{ + len: 1, + words: [4]uint64{1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBit(tt.bit) + if !got.Equal(&tt.want) { + t.Errorf("SetBit(%v) = %v, want %v", tt.bit, got, tt.want) + } + }) + } +} + +func TestCmp(t *testing.T) { + tests := []struct { + name string + x BitArray + y BitArray + want int + }{ + { + name: "equal empty arrays", + x: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + y: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: 0, + }, + { + name: "equal non-empty arrays", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 0, + }, + { + name: "different lengths - x shorter", + x: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "different lengths - x longer", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 32, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, x < y in first word", + x: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + want: -1, + }, + { + name: "same length, x > y in first word", + x: BitArray{len: 64, words: [4]uint64{maxUint64, 0, 0, 0}}, + y: BitArray{len: 64, words: [4]uint64{0xFFFFFFFFFFFFFFFE, 0, 0, 0}}, + want: 1, + }, + { + name: "same length, difference in last word", + x: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFFF}}, + y: BitArray{len: 251, words: [4]uint64{0, 0, 0, 0x7FFFFFFFFFFFFF0}}, + want: 1, + }, + { + name: "same length, sparse bits", + x: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAAAAA, 0, 0}}, + y: BitArray{len: 128, words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0xAAAAAAAAAAAAA000, 0, 0}}, + want: 1, + }, + { + name: "max length difference", + x: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + y: BitArray{len: 1, words: [4]uint64{0x1, 0, 0, 0}}, + want: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.x.Cmp(&tt.y) + if got != tt.want { + t.Errorf("Cmp() = %v, want %v", got, tt.want) + } + + // Test anti-symmetry: if x.Cmp(y) = z then y.Cmp(x) = -z + gotReverse := tt.y.Cmp(&tt.x) + if gotReverse != -tt.want { + t.Errorf("Reverse Cmp() = %v, want %v", gotReverse, -tt.want) + } + + // Test transitivity with self: x.Cmp(x) should always be 0 + if tt.x.Cmp(&tt.x) != 0 { + t.Error("Self Cmp() != 0") + } + }) + } +} + +func TestSetBytes(t *testing.T) { + tests := []struct { + name string + length uint8 + data []byte + want BitArray + }{ + { + name: "empty data", + length: 0, + data: []byte{}, + want: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "single byte", + length: 8, + data: []byte{0xFF}, + want: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + }, + { + name: "two bytes", + length: 16, + data: []byte{0xAA, 0xFF}, + want: BitArray{ + len: 16, + words: [4]uint64{0xAAFF, 0, 0, 0}, + }, + }, + { + name: "three bytes", + length: 24, + data: []byte{0xAA, 0xBB, 0xCC}, + want: BitArray{ + len: 24, + words: [4]uint64{0xAABBCC, 0, 0, 0}, + }, + }, + { + name: "four bytes", + length: 32, + data: []byte{0xAA, 0xBB, 0xCC, 0xDD}, + want: BitArray{ + len: 32, + words: [4]uint64{0xAABBCCDD, 0, 0, 0}, + }, + }, + { + name: "eight bytes (full word)", + length: 64, + data: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + want: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "sixteen bytes (two words)", + length: 128, + data: []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + }, + want: BitArray{ + len: 128, + words: [4]uint64{ + 0xAAAAAAAAAAAAAAAA, + 0xFFFFFFFFFFFFFFFF, + 0, 0, + }, + }, + }, + { + name: "thirty-two bytes (full array)", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + { + name: "truncate to length", + length: 4, + data: []byte{0xFF}, + want: BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + }, + { + name: "data larger than 32 bytes", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, // extra bytes should be ignored + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBytes(tt.length, tt.data) + if !got.Equal(&tt.want) { + t.Errorf("SetBytes(%d, %v) = %v, want %v", tt.length, tt.data, got, tt.want) + } + }) + } +} + +func TestSubset(t *testing.T) { + tests := []struct { + name string + x *BitArray + startPos uint8 + endPos uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + startPos: 0, + endPos: 0, + want: emptyBitArray, + }, + { + name: "invalid range - start >= end", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 4, + endPos: 2, + want: emptyBitArray, + }, + { + name: "invalid range - start >= length", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 8, + endPos: 10, + want: emptyBitArray, + }, + { + name: "full range", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 0, + endPos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "middle subset", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 2, + endPos: 5, + want: &BitArray{ + len: 3, + words: [4]uint64{0x7, 0, 0, 0}, // 111 + }, + }, + { + name: "end subset", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 5, + endPos: 8, + want: &BitArray{ + len: 3, + words: [4]uint64{0x7, 0, 0, 0}, // 111 + }, + }, + { + name: "start subset", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 0, + endPos: 3, + want: &BitArray{ + len: 3, + words: [4]uint64{0x7, 0, 0, 0}, // 111 + }, + }, + { + name: "endPos beyond length", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + startPos: 4, + endPos: 10, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + startPos: 2, + endPos: 6, + want: &BitArray{ + len: 4, + words: [4]uint64{0xA, 0, 0, 0}, // 1010 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}, + }, + startPos: 60, + endPos: 68, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + }, + { + name: "max length subset", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + startPos: 1, + endPos: 251, + want: &BitArray{ + len: 250, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x3FFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Subset(tt.x, tt.startPos, tt.endPos) + if !got.Equal(tt.want) { + t.Errorf("Subset(%v, %d, %d) = %v, want %v", tt.x, tt.startPos, tt.endPos, got, tt.want) + } + }) + } +} diff --git a/db/buckets.go b/db/buckets.go index e5037378a3..329eb15993 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -38,6 +38,9 @@ const ( MempoolTail // key of the tail node MempoolLength // number of transactions MempoolNode + ClassTrie // ClassTrie + Node path -> Trie Node + ContractTrieContract // ContractTrieContract + Node path -> Trie Node + ContractTrieStorage // ContractTrieStorage + Contract Address + Node path -> Trie Node ) // Key flattens a prefix and series of byte arrays into a single []byte. diff --git a/utils/orderedset.go b/utils/orderedset.go index 9d78809731..b5ab2a6680 100644 --- a/utils/orderedset.go +++ b/utils/orderedset.go @@ -75,3 +75,11 @@ func (o *OrderedSet[K, V]) Keys() []K { } return keys } + +func (o *OrderedSet[K, V]) Clear() { + o.lock.Lock() + defer o.lock.Unlock() + + o.items = nil + o.itemPos = make(map[K]int) +}