Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
weiihann committed Dec 31, 2024
1 parent 4f7b564 commit 2401a74
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 43 deletions.
26 changes: 19 additions & 7 deletions core/trie2/bitarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,12 @@ func (b *BitArray) Equal(x *BitArray) bool {

// Returns true if bit n-th is set, where n = 0 is LSB.
func (b *BitArray) IsBitSetFromLSB(n uint8) bool {
return b.BitSetFromLSB(n) == 1
return b.BitFromLSB(n) == 1
}

// Returns the bit value at position n, where n = 0 is LSB.
// If n is out of bounds, returns 0.
func (b *BitArray) BitSetFromLSB(n uint8) uint8 {
func (b *BitArray) BitFromLSB(n uint8) uint8 {
if n >= b.len {
return 0
}
Expand All @@ -381,26 +381,26 @@ func (b *BitArray) BitSetFromLSB(n uint8) uint8 {
}

func (b *BitArray) IsBitSet(n uint8) bool {
return b.BitSet(n) == 1
return b.Bit(n) == 1
}

// Returns the bit value at position n, where n = 0 is MSB.
// If n is out of bounds, returns 0.
func (b *BitArray) BitSet(n uint8) uint8 {
func (b *BitArray) Bit(n uint8) uint8 {
if n >= b.Len() {
return 0
}

return b.BitSetFromLSB(b.Len() - n - 1)
return b.BitFromLSB(b.Len() - n - 1)
}

// Returns the bit value at the most significant bit
func (b *BitArray) MSB() uint8 {
return b.BitSet(0)
return b.Bit(0)
}

func (b *BitArray) LSB() uint8 {
return b.BitSetFromLSB(0)
return b.BitFromLSB(0)
}

func (b *BitArray) IsEmpty() bool {
Expand Down Expand Up @@ -479,6 +479,18 @@ func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray {
return b
}

// Sets the bit array to a single bit.
func (b *BitArray) SetBit(bit bool) *BitArray {
b.len = 1
if bit {
b.words[0] = 1
} else {
b.words[0] = 0
}
b.truncateToLength()
return b
}

// Returns the length of the encoded bit array in bytes.
func (b *BitArray) EncodedLen() uint {
return b.byteCount() + 1
Expand Down
1 change: 1 addition & 0 deletions core/trie2/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ func (n *edgeNode) pathMatches(key *BitArray) bool {
return n.path.EqualMSBs(key)
}

// Returns the common bits between the current node and the given key, starting from the most significant bit
func (n *edgeNode) commonPath(key *BitArray) BitArray {
var commonPath BitArray
commonPath.CommonMSBs(n.path, key)
Expand Down
81 changes: 58 additions & 23 deletions core/trie2/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ func NewTrie(height uint8) *Trie {
return &Trie{height: height}
}

// Modifies or inserts a key-value pair in the trie.
// If value is zero, the key is deleted from the trie.
func (t *Trie) Update(key, value *felt.Felt) error {
// if t.commited {
// return ErrCommitted
// }
return t.update(key, value)
}

// Retrieves the value associated with the given key.
// Returns felt.Zero if the key doesn't exist.
// May update the trie's internal structure if nodes need to be resolved.
func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) {
k := t.FeltToKey(key)
// TODO(weiihann): get the value directly from the reader
Expand All @@ -39,11 +44,17 @@ func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) {
return val, err
}

// Removes the given key from the trie.
func (t *Trie) Delete(key *felt.Felt) error {
panic("TODO(weiihann): implement me")
k := t.FeltToKey(key)
_, n, err := t.delete(t.root, new(BitArray), &k)
if err != nil {
return err
}
t.root = n
return nil
}

// Traverses the trie recursively to find the value that corresponds to the key.
func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) {
switch n := n.(type) {
case *edgeNode:
Expand Down Expand Up @@ -75,10 +86,12 @@ func (t *Trie) get(n node, key *BitArray) (*felt.Felt, node, bool, error) {
}
}

// Modifies the trie by either inserting/updating a value or deleting a key.
// The operation is determined by whether the value is zero (delete) or non-zero (insert/update).
func (t *Trie) update(key, value *felt.Felt) error {
k := t.FeltToKey(key)
if value.IsZero() {
_, n, err := t.delete(t.root, &k)
_, n, err := t.delete(t.root, new(BitArray), &k)
if err != nil {
return err
}
Expand All @@ -104,8 +117,8 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) {

switch n := n.(type) {
case *edgeNode:
match := n.commonPath(key)
// If the whole key matches, just keep this edge node as it is and update the value
match := n.commonPath(key) // get the matching bits between the current node and the key
// If the match is the same as the path, just keep this edge node as it is and update the value
if match.Len() == n.path.Len() {
dirty, newNode, err := t.insert(n.child, key.LSBs(key, match.Len()), value)
if !dirty || err != nil {
Expand All @@ -117,15 +130,15 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) {
flags: newFlag(),
}, nil
}
// Otherwise branch out at the bit index where they differ
// Otherwise branch out at the bit position where they differ
branch := &binaryNode{flags: newFlag()}
var err error
_, branch.children[n.path.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child)
_, branch.children[n.path.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(n.path, match.Len()+1), n.child)
if err != nil {
return false, n, err
}

_, branch.children[key.BitSet(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value)
_, branch.children[key.Bit(match.Len())], err = t.insert(nil, new(BitArray).LSBs(key, match.Len()+1), value)
if err != nil {
return false, n, err
}
Expand All @@ -135,22 +148,27 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) {
return true, branch, nil
}

// Otherwise, create a new edge node with the path being the common path and the branch as the child
return true, &edgeNode{path: new(BitArray).MSBs(key, match.Len()), child: branch, flags: newFlag()}, nil

case *binaryNode:
// Go to the child node based on the MSB of the key
bit := key.MSB()
dirty, newNode, err := t.insert(n.children[bit], new(BitArray).LSBs(key, 1), value)
if !dirty || err != nil {
return false, n, err
}
// Replace the child node with the new node
n = n.copy()
n.flags = newFlag()
n.children[bit] = newNode
return true, n, nil
case nil:
// We reach the end of the key, return the value node
if key.IsEmpty() {
return true, value, nil
}
// Otherwise, return a new edge node with the path being the key and the value as the child
return true, &edgeNode{path: key, child: value, flags: newFlag()}, nil
case hashNode:
panic("TODO(weiihann): implement me")
Expand All @@ -159,47 +177,62 @@ func (t *Trie) insert(n node, key *BitArray, value node) (bool, node, error) {
}
}

func (t *Trie) delete(n node, key *BitArray) (bool, node, error) {
func (t *Trie) delete(n node, prefix, key *BitArray) (bool, node, error) {
switch n := n.(type) {
case *edgeNode:
match := n.commonPath(key)
// Mismatched, don't do anything
if match.Len() < n.path.Len() {
return false, n, nil
}
// If the whole key matches, just delete the edge node
// If the whole key matches, remove the entire edge node
if match.Len() == key.Len() {
return true, nil, nil
}

// Otherwise, we need to delete the child node
dirty, child, err := t.delete(n.child, key.LSBs(key, match.Len()))
// Otherwise, key is longer than current node path, so we need to delete the child.
// Child can never be nil because it's guaranteed that we have at least 2 other values in the subtrie.
keyPrefix := new(BitArray).MSBs(key, n.path.Len())
dirty, child, err := t.delete(n.child, new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, n.path.Len()))
if !dirty || err != nil {
return false, n, err
}
switch child := child.(type) {
case *edgeNode:
return true, &edgeNode{path: n.path, child: child.child, flags: newFlag()}, nil
return true, &edgeNode{path: new(BitArray).Append(n.path, child.path), child: child.child, flags: newFlag()}, nil
default:
return true, &edgeNode{path: n.path, child: child, flags: newFlag()}, nil
return true, &edgeNode{path: new(BitArray).Set(n.path), child: child, flags: newFlag()}, nil
}
case *binaryNode:
bit := key.MSB()
dirty, newNode, err := t.delete(n.children[bit], key.LSBs(key, 1))
keyPrefix := new(BitArray).MSBs(key, 1)
dirty, newNode, err := t.delete(n.children[bit], new(BitArray).Append(prefix, keyPrefix), key.LSBs(key, 1))
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = newFlag()
n.children[bit] = newNode

// If the child node is not nil, that means we still have 2 children in this binary node
if newNode != nil {
return true, n, nil
}

// TODO(weiihann): combine this binary node with the child
// Otherwise, we need to combine this binary node with the other child
other := bit ^ 1
bitPrefix := new(BitArray).SetBit(other == 1)
if cn, ok := n.children[other].(*edgeNode); ok { // other child is an edge node, append the bit prefix to the child path
return true, &edgeNode{
path: new(BitArray).Append(bitPrefix, cn.path),
child: cn.child,
flags: newFlag(),
}, nil
}

return true, n, nil
// other child is not an edge node, create a new edge node with the bit prefix as the path
// containing the other child as the child
return true, &edgeNode{path: bitPrefix, child: n.children[other], flags: newFlag()}, nil
case valueNode:
return true, nil, nil
case nil:
Expand All @@ -211,15 +244,17 @@ func (t *Trie) delete(n node, key *BitArray) (bool, node, error) {
}
}

// Converts a Felt value into a BitArray representation suitable for
// use as a trie key with the specified height.
func (t *Trie) FeltToKey(f *felt.Felt) BitArray {
var key BitArray
key.SetFelt(t.height, f)
return key
}

func (t *Trie) String() string {
if t.root == nil {
return ""
}
return t.root.String()
}

func (t *Trie) FeltToKey(f *felt.Felt) BitArray {
var key BitArray
key.SetFelt(t.height, f)
return key
}
59 changes: 46 additions & 13 deletions core/trie2/trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ import (
)

func TestUpdate(t *testing.T) {
trie := NewTrie(251)
tr, records := nonRandomTrie(t, 1000)

key := new(felt.Felt).SetUint64(1)
value := new(felt.Felt).SetUint64(2)
err := trie.Update(key, value)
require.NoError(t, err)
for _, record := range records {
err := tr.Update(record.key, record.value)
require.NoError(t, err)

got, err := trie.Get(key)
require.NoError(t, err)
require.Equal(t, value, got)
got, err := tr.Get(record.key)
require.NoError(t, err)
require.Equal(t, record.value, got)
}
}

func TestUpdateRandom(t *testing.T) {
Expand All @@ -34,16 +34,51 @@ func TestUpdateRandom(t *testing.T) {
}
}

func Test4KeysTrieD(t *testing.T) {
tr, _ := build4KeysTrieD(t)
t.Log(tr.String())
func TestDelete(t *testing.T) {
tr, records := nonRandomTrie(t, 10000)

for _, record := range records {
err := tr.Delete(record.key)
require.NoError(t, err)

got, err := tr.Get(record.key)
require.NoError(t, err)
require.Equal(t, got, &felt.Zero)
}
}

func TestDeleteRandom(t *testing.T) {
tr, records := randomTrie(t, 10000)

for i := len(records) - 1; i >= 0; i-- {
err := tr.Delete(records[i].key)
require.NoError(t, err)

got, err := tr.Get(records[i].key)
require.NoError(t, err)
require.Equal(t, got, &felt.Zero)
}
}

type keyValue struct {
key *felt.Felt
value *felt.Felt
}

func nonRandomTrie(t *testing.T, numKeys int) (*Trie, []*keyValue) {
tr := NewTrie(251)
records := make([]*keyValue, numKeys)

for i := 1; i < numKeys+1; i++ {
key := new(felt.Felt).SetUint64(uint64(i))
records[i-1] = &keyValue{key: key, value: key}
err := tr.Update(key, key)
require.NoError(t, err)
}

return tr, records
}

func randomTrie(t testing.TB, n int) (*Trie, []*keyValue) {
rrand := rand.New(rand.NewSource(3))

Expand Down Expand Up @@ -80,8 +115,6 @@ func buildTrie(t *testing.T, records []*keyValue) *Trie {

for _, record := range records {
err := tempTrie.Update(record.key, record.value)
t.Log("--------------------------------")
t.Log(tempTrie.String())
require.NoError(t, err)
}

Expand Down

0 comments on commit 2401a74

Please sign in to comment.