Skip to content

Commit

Permalink
update the tree logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jackzampolin committed Dec 5, 2024
1 parent b08bc01 commit e2adcf1
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 143 deletions.
17 changes: 14 additions & 3 deletions gturbine/builder/helpers.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package builder

import (
"bytes"

"github.com/gordian-engine/gordian/gturbine"
"github.com/gordian-engine/gordian/tm/tmconsensus"
)

// FindLayerPosition finds a validator's layer and index in the tree
Expand All @@ -13,7 +16,7 @@ func FindLayerPosition(tree *gturbine.Tree, pubKey []byte) (*gturbine.Layer, int
layer := tree.Root
for layer != nil {
for i, v := range layer.Validators {
if string(v.PubKey) == string(pubKey) {
if bytes.Equal(v.PubKey.PubKeyBytes(), pubKey) {
return layer, i
}
}
Expand All @@ -27,7 +30,7 @@ func FindLayerPosition(tree *gturbine.Tree, pubKey []byte) (*gturbine.Layer, int
}

// GetChildren returns the validators that should receive forwarded shreds
func GetChildren(tree *gturbine.Tree, pubKey []byte) []gturbine.Validator {
func GetChildren(tree *gturbine.Tree, pubKey []byte) []tmconsensus.Validator {
layer, idx := FindLayerPosition(tree, pubKey)
if layer == nil || len(layer.Children) == 0 {
return nil
Expand All @@ -37,6 +40,14 @@ func GetChildren(tree *gturbine.Tree, pubKey []byte) []gturbine.Validator {
childLayer := layer.Children[0]
startIdx := (idx * len(childLayer.Validators)) / len(layer.Validators)
endIdx := ((idx + 1) * len(childLayer.Validators)) / len(layer.Validators)
if startIdx == endIdx {
endIdx = startIdx + 1
}

// Ensure endIdx doesn't exceed slice bounds
if endIdx > len(childLayer.Validators) {
endIdx = len(childLayer.Validators)
}

return childLayer.Validators[startIdx:endIdx]
}
}
105 changes: 71 additions & 34 deletions gturbine/builder/helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,64 +1,101 @@
package builder

import (
"bytes"
"crypto/ed25519"
"testing"

"github.com/gordian-engine/gordian/gcrypto"
"github.com/gordian-engine/gordian/gturbine"
"github.com/gordian-engine/gordian/tm/tmconsensus"
)

func TestHelpers(t *testing.T) {
makeTree := func() *gturbine.Tree {
b := NewTreeBuilder(2)
validators := []gturbine.Validator{
{PubKey: ed25519.PublicKey{1}, Stake: 100},
{PubKey: ed25519.PublicKey{2}, Stake: 200},
{PubKey: ed25519.PublicKey{3}, Stake: 300},
{PubKey: ed25519.PublicKey{4}, Stake: 400},
{PubKey: ed25519.PublicKey{5}, Stake: 500},
func makeTestTree(validatorCount int) *gturbine.Tree {
b := NewTreeBuilder(200)
validators := make([]tmconsensus.Validator, validatorCount)

for i := 0; i < validatorCount; i++ {
pubKey := ed25519.PublicKey(make([]byte, ed25519.PublicKeySize))
pubKey[0] = byte(i % 255)
pubKey[1] = byte(i / 255)
validators[i] = tmconsensus.Validator{
PubKey: gcrypto.Ed25519PubKey(pubKey),
Power: uint64((validatorCount - i) * 100),
}
tree, _ := b.BuildTree(validators, 1, 0)
return tree
}

t.Run("find position", func(t *testing.T) {
tree := makeTree()

layer, idx := FindLayerPosition(tree, []byte{1})
tree, _ := b.BuildTree(validators, 1, 0)
return tree
}

func TestHelpers(t *testing.T) {
t.Run("find position in large tree", func(t *testing.T) {
tree := makeTestTree(500)

// Test finding root validator
searchKey := make([]byte, ed25519.PublicKeySize)
searchKey[0] = byte(0 % 255)
searchKey[1] = byte(0 / 255)
layer, idx := FindLayerPosition(tree, searchKey)
if layer == nil {
t.Fatal("validator not found")
}

// Verify the found validator has our search key
foundKey := layer.Validators[idx].PubKey.PubKeyBytes()
if !bytes.Equal(searchKey, foundKey) {
t.Errorf("found wrong validator: want %v, got %v", searchKey, foundKey)
}

// Test finding last layer validator
searchKey[0] = byte(499 % 255)
searchKey[1] = byte(499 / 255)
layer, idx = FindLayerPosition(tree, searchKey)
if layer == nil {
t.Fatal("validator not found")
}
if idx == -1 {
t.Error("invalid index returned")
foundKey = layer.Validators[idx].PubKey.PubKeyBytes()
if !bytes.Equal(searchKey, foundKey) {
t.Errorf("found wrong validator: want %v, got %v", searchKey, foundKey)
}

// Test unknown validator
layer, idx = FindLayerPosition(tree, []byte{99})
// Test non-existent validator
badKey := make([]byte, ed25519.PublicKeySize)
badKey[0] = 255
badKey[1] = 255
layer, idx = FindLayerPosition(tree, badKey)
if layer != nil || idx != -1 {
t.Error("found non-existent validator")
}
})

t.Run("get children", func(t *testing.T) {
tree := makeTree()
t.Run("children distribution in large tree", func(t *testing.T) {
tree := makeTestTree(500)

// Root validator should have children
children := GetChildren(tree, tree.Root.Validators[0].PubKey)
if len(children) == 0 {
t.Error("root validator should have children")
// Root layer validators should each get 1 child
for i, v := range tree.Root.Validators {
children := GetChildren(tree, v.PubKey.PubKeyBytes())
expectedCount := 1 // 200 validators divided by 200 fanout
if len(children) != expectedCount {
t.Errorf("validator %d: expected %d children, got %d", i, expectedCount, len(children))
}
}

// Last layer validator should have no children
lastLayer := tree.Root.Children[0].Children[0]
leafChildren := GetChildren(tree, lastLayer.Validators[0].PubKey)
if len(leafChildren) != 0 {
t.Error("leaf validator should have no children")
// Middle layer validators should each get 0 or 1 child
for i, v := range tree.Root.Children[0].Validators {
children := GetChildren(tree, v.PubKey.PubKeyBytes())
if len(children) > 1 {
t.Errorf("middle layer validator %d has too many children: %d", i, len(children))
}
}

// Check distribution
rootChildren := GetChildren(tree, tree.Root.Validators[0].PubKey)
if len(rootChildren) != 1 {
t.Errorf("expected 1 child for first root validator, got %d", len(rootChildren))
// Last layer validators should have no children
lastLayer := tree.Root.Children[0].Children[0]
for i, v := range lastLayer.Validators {
children := GetChildren(tree, v.PubKey.PubKeyBytes())
if len(children) != 0 {
t.Errorf("last layer validator %d has children when it shouldn't: %d", i, len(children))
}
}
})
}
24 changes: 11 additions & 13 deletions gturbine/builder/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package builder
import (
"crypto/sha256"
"encoding/binary"
"sort"

"github.com/gordian-engine/gordian/gturbine"
"github.com/gordian-engine/gordian/tm/tmconsensus"
)

type TreeBuilder struct {
Expand All @@ -19,27 +19,25 @@ func NewTreeBuilder(fanout uint32) *TreeBuilder {
}

// BuildTree creates a new propagation tree with stake-weighted validator ordering
func (b *TreeBuilder) BuildTree(validators []gturbine.Validator, slot uint64, shredIndex uint32) (*gturbine.Tree, error) {
func (b *TreeBuilder) BuildTree(validators []tmconsensus.Validator, slot uint64, shredIndex uint32) (*gturbine.Tree, error) {
if len(validators) == 0 {
return nil, nil
}

// Sort validators by stake
sortedVals := make([]gturbine.Validator, len(validators))
// Sort validators by stake (power) and pubkey
sortedVals := make([]tmconsensus.Validator, len(validators))
copy(sortedVals, validators)
sort.Slice(sortedVals, func(i, j int) bool {
return sortedVals[i].Stake > sortedVals[j].Stake
})
tmconsensus.SortValidators(sortedVals)

// Generate deterministic seed for shuffling
seed := b.deriveTreeSeed(slot, shredIndex, 0)

// Fisher-Yates shuffle with deterministic seed
for i := len(sortedVals) - 1; i > 0; i-- {
// Use seed to generate index
j := int(binary.LittleEndian.Uint64(seed) % uint64(i+1))
sortedVals[i], sortedVals[j] = sortedVals[j], sortedVals[i]

// Update seed for next iteration
h := sha256.New()
h.Write(seed)
Expand Down Expand Up @@ -79,17 +77,17 @@ func (b *TreeBuilder) BuildTree(validators []gturbine.Validator, slot uint64, sh
// deriveTreeSeed generates deterministic seed for tree creation
func (b *TreeBuilder) deriveTreeSeed(slot uint64, shredIndex uint32, shredType uint8) []byte {
h := sha256.New()

slotBytes := make([]byte, 8)
binary.LittleEndian.PutUint64(slotBytes, slot)
h.Write(slotBytes)

shredBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(shredBytes, shredIndex)
h.Write(shredBytes)

h.Write([]byte{shredType})

return h.Sum(nil)
}

Expand Down
Loading

0 comments on commit e2adcf1

Please sign in to comment.