From 4fa2b05e7c57d36af984604c235342dd49ff56fd Mon Sep 17 00:00:00 2001 From: Jack Zampolin Date: Wed, 4 Dec 2024 12:49:33 -0800 Subject: [PATCH 1/4] a turbine implementation for gordian into debugging the encoding tests Passing encoding and shredding testing adding shredding tests refactor shredding to do everything in processor. failing tests, need to get them passing. passing processor tests and expanded coverage Add group id to each shred for over the wire push collection logic, split out shred group into sep file push Get basic tests passing use map w mu binary encoding update the tree logic add network transport w/ tests fix naming for gtbuilder and gtnetwork Remove gshred package gtshredding -> gtshred refactor builder to remove outside deps and be more efficent Fix memory leak when shards from already decoded blocks come in Use the shardgroup.reset function Fix timing issue in the test add a first pass at benchmarks remove unused config type expire stale groups. use rwmutexes where appropriate use lock in shredgroup Reset improve perf fix test race clean up processor crud Push basic readme for gturbine --- go.mod | 2 +- gturbine/README.md | 78 +++++ gturbine/gtbuilder/benchmark_test.go | 131 +++++++ gturbine/gtbuilder/helpers.go | 48 +++ gturbine/gtbuilder/helpers_test.go | 77 +++++ gturbine/gtbuilder/tree.go | 96 ++++++ gturbine/gtbuilder/tree_test.go | 88 +++++ gturbine/gtencoding/benchmark_test.go | 323 ++++++++++++++++++ .../gtencoding/binary_codec_bench_test.go | 165 +++++++++ gturbine/gtencoding/binary_encoder.go | 120 +++++++ gturbine/gtencoding/binary_encoder_test.go | 153 +++++++++ gturbine/gtencoding/encoder.go | 8 + gturbine/gtencoding/erasure.go | 88 +++++ gturbine/gtencoding/erasure_test.go | 189 ++++++++++ gturbine/gtnetwork/transport.go | 138 ++++++++ gturbine/gtnetwork/transport_test.go | 105 ++++++ gturbine/gtshred/benchmark_test.go | 133 ++++++++ gturbine/gtshred/process_shred_test.go | 298 ++++++++++++++++ gturbine/gtshred/processor.go | 238 +++++++++++++ gturbine/gtshred/processor_test.go | 224 ++++++++++++ gturbine/gtshred/shred_group.go | 229 +++++++++++++ gturbine/turbine.go | 53 +++ gturbine/turbine_test.go | 204 +++++++++++ 23 files changed, 3187 insertions(+), 1 deletion(-) create mode 100644 gturbine/README.md create mode 100644 gturbine/gtbuilder/benchmark_test.go create mode 100644 gturbine/gtbuilder/helpers.go create mode 100644 gturbine/gtbuilder/helpers_test.go create mode 100644 gturbine/gtbuilder/tree.go create mode 100644 gturbine/gtbuilder/tree_test.go create mode 100644 gturbine/gtencoding/benchmark_test.go create mode 100644 gturbine/gtencoding/binary_codec_bench_test.go create mode 100644 gturbine/gtencoding/binary_encoder.go create mode 100644 gturbine/gtencoding/binary_encoder_test.go create mode 100644 gturbine/gtencoding/encoder.go create mode 100644 gturbine/gtencoding/erasure.go create mode 100644 gturbine/gtencoding/erasure_test.go create mode 100644 gturbine/gtnetwork/transport.go create mode 100644 gturbine/gtnetwork/transport_test.go create mode 100644 gturbine/gtshred/benchmark_test.go create mode 100644 gturbine/gtshred/process_shred_test.go create mode 100644 gturbine/gtshred/processor.go create mode 100644 gturbine/gtshred/processor_test.go create mode 100644 gturbine/gtshred/shred_group.go create mode 100644 gturbine/turbine.go create mode 100644 gturbine/turbine_test.go diff --git a/go.mod b/go.mod index 82e7120..57f9014 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/tv42/httpunix 
v0.0.0-20191220191345-2ba4b9c3382c golang.org/x/crypto v0.27.0 golang.org/x/tools v0.22.0 + google.golang.org/protobuf v1.34.2 ) require ( @@ -141,7 +142,6 @@ require ( golang.org/x/text v0.18.0 // indirect golang.org/x/time v0.6.0 // indirect gonum.org/v1/gonum v0.15.1 // indirect - google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/blake3 v1.2.2 // indirect ) diff --git a/gturbine/README.md b/gturbine/README.md new file mode 100644 index 0000000..d97f10e --- /dev/null +++ b/gturbine/README.md @@ -0,0 +1,78 @@ +# GTurbine 🌪️ + +GTurbine is a high-performance block propagation protocol designed for distributed consensus systems. It uses a structured network topology and Reed-Solomon erasure coding to efficiently propagate blocks across large validator networks while minimizing bandwidth requirements. + +*Because flooding blocks to every node is **so** 2019...* + +## Overview + +GTurbine implements a multi-layer propagation strategy inspired by Solana's Turbine protocol. Rather than having proposers flood blocks to every validator, GTurbine orchestrates a structured propagation flow: + +1. **Block Creation** (Proposer): + - Proposer reaps transactions from mempool + - Assembles them into a new block + - Calculates block header and metadata + +2. **Block Shredding** (Proposer): + - Splits block into fixed-size data shreds + - Applies Reed-Solomon erasure coding + - Generates recovery shreds for fault tolerance + - Tags all shreds with unique group ID and metadata + +3. **Tree Organization** (Network-wide): + - Validators self-organize into propagation layers + - Each validator knows its upstream source and downstream targets + - Layer assignments are deterministic and stake-weighted + - Tree structure changes periodically to prevent targeted attacks + +4. **Initial Propagation** (Proposer → Layer 1): + - Proposer distributes different shreds to each Layer 1 validator + - Each Layer 1 validator receives unique subset of block data + - Distribution uses UDP for low latency + +5. **Cascade Propagation** (Layer N → Layer N+1): + - Each validator forwards received shreds to assigned downstream nodes + - Propagation continues until leaves of tree are reached + - Different paths carry different shreds + +6. **Block Reconstruction** (All Nodes): + - Validators collect shreds from upstream nodes + - Once minimum threshold of shreds received (data + recovery) + - Reed-Solomon decoding recovers any missing pieces + - Original block is reconstructed and verified + +This structured approach transforms the bandwidth requirement at each node from O(n) in a flood-based system to O(log n), where n is the number of validators. By leveraging erasure coding and tree-based propagation, GTurbine achieves reliable block distribution while minimizing network congestion and single-node bandwidth requirements. 
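To make the flow above concrete, here is a minimal sketch that wires these steps together using the packages introduced in this patch (`gtshred` and `gtbuilder`). It is illustrative only: the block payload, validator set, fanout, shred counts, and chunk size are placeholder values, the UDP hop through `gtnetwork` is omitted, and error handling is elided.

```go
// Proposer side and validator side of the flow described above, assuming the
// gtshred and gtbuilder APIs added in this patch. All concrete numbers here
// (block size, shred counts, chunk size, fanout, validator count) are
// illustrative placeholders.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/gordian-engine/gordian/gturbine/gtbuilder"
	"github.com/gordian-engine/gordian/gturbine/gtshred"
)

// blockSink is a stand-in callback invoked once a block has been reassembled.
type blockSink struct{}

func (*blockSink) ProcessBlock(height uint64, blockHash []byte, block []byte) error {
	fmt.Printf("reassembled block at height %d (%d bytes)\n", height, len(block))
	return nil
}

func main() {
	// Steps 1-2: shred an 8 MB block into 32 data shreds (256 KB chunks)
	// plus 32 Reed-Solomon recovery shreds, all tagged with one group ID.
	block := make([]byte, 8<<20)
	group, err := gtshred.NewShredGroup(block, 1 /* height */, 32, 32, 1<<18)
	if err != nil {
		panic(err)
	}

	// Step 3: build the deterministic propagation tree for this slot and ask
	// which downstream validators a given validator forwards its shreds to.
	indices := make([]uint64, 500)
	for i := range indices {
		indices[i] = uint64(i)
	}
	tree, err := gtbuilder.NewTreeBuilder(200).BuildTree(indices, 1 /* slot */, 0)
	if err != nil {
		panic(err)
	}
	first := tree.Root.Validators[0]
	fmt.Println("validator", first, "forwards to", gtbuilder.GetChildren(tree, first))

	// Steps 4-6: on the receiving side, a Processor collects shreds in any
	// order (over UDP in practice) and fires ProcessBlock once enough
	// data + recovery shreds have arrived to reconstruct the block.
	p := gtshred.NewProcessor(&blockSink{}, time.Minute)
	go p.RunBackgroundCleanup(context.Background())

	for _, s := range group.DataShreds[4:] { // pretend four data shreds were lost
		_ = p.CollectShred(s)
	}
	for _, s := range group.RecoveryShreds {
		_ = p.CollectShred(s)
	}
}
```

With 32 data and 32 recovery shreds, any 32 of the 64 pieces are enough to rebuild the block, which is why the receiver above still completes reconstruction even though four data shreds never arrive.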
+ +## Architecture + +### Core Components + +``` +gturbine/ +├── gtbuilder/ - Tree construction and management +├── gtencoding/ - Erasure coding and shred serialization +├── gtnetwork/ - Network transport and routing +├── gtshred/ - Block shredding and reconstruction +└── turbine.go - Core interfaces and types +``` + +### Key Features + +- **Efficient Erasure Coding**: Uses the battle-tested [klauspost/reedsolomon](https://github.com/klauspost/reedsolomon) library +- **Flexible Tree Structure**: Configurable fanout and layer organization +- **UDP Transport**: Low-latency propagation with erasure coding for reliability +- **Safe Partial Recovery**: Reconstruct blocks with minimum required shreds +- **Built-in Verification**: Integrity checking at both shred and block levels + +## Acknowledgments + +GTurbine's was made possible by: +- [Solana's Turbine Protocol](https://docs.solana.com/cluster/turbine-block-propagation) +- Academic work on reliable multicast protocols +- The incredible [klauspost/reedsolomon](https://github.com/klauspost/reedsolomon) library + +## Support + +Found a bug? Have a feature request? Open an issue! + +*Remember: In distributed systems, eventual consistency is better than eventual insanity* 😉 \ No newline at end of file diff --git a/gturbine/gtbuilder/benchmark_test.go b/gturbine/gtbuilder/benchmark_test.go new file mode 100644 index 0000000..9b72c8f --- /dev/null +++ b/gturbine/gtbuilder/benchmark_test.go @@ -0,0 +1,131 @@ +package gtbuilder + +import ( + "testing" +) + +func BenchmarkTreeBuilder(b *testing.B) { + sizes := []struct { + name string + valCount int + fanout uint32 + }{ + {"100-Validators-200-Fanout", 100, 200}, + {"500-Validators-200-Fanout", 500, 200}, + {"1000-Validators-200-Fanout", 1000, 200}, + {"2000-Validators-200-Fanout", 2000, 200}, + {"500-Validators-100-Fanout", 500, 100}, + {"500-Validators-400-Fanout", 500, 400}, + } + + for _, size := range sizes { + // Create indices array + indices := make([]uint64, size.valCount) + for i := range indices { + indices[i] = uint64(i) + } + + b.Run(size.name, func(b *testing.B) { + builder := NewTreeBuilder(size.fanout) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Make copy of indices since they get modified + indicesCopy := make([]uint64, len(indices)) + copy(indicesCopy, indices) + + tree, err := builder.BuildTree(indicesCopy, uint64(i), 0) + if err != nil { + b.Fatal(err) + } + if tree == nil { + b.Fatal("expected non-nil tree") + } + } + }) + } +} + +func BenchmarkFindLayerPosition(b *testing.B) { + sizes := []struct { + name string + valCount int + searchPct float64 // percentage through validator set to search + }{ + {"500-Val-First-10pct", 500, 0.1}, + {"500-Val-Middle", 500, 0.5}, + {"500-Val-Last-10pct", 500, 0.9}, + {"2000-Val-First-10pct", 2000, 0.1}, + {"2000-Val-Middle", 2000, 0.5}, + {"2000-Val-Last-10pct", 2000, 0.9}, + } + + for _, size := range sizes { + indices := make([]uint64, size.valCount) + for i := range indices { + indices[i] = uint64(i) + } + + builder := NewTreeBuilder(200) + tree, _ := builder.BuildTree(indices, 1, 0) + searchIdx := uint64(float64(size.valCount) * size.searchPct) + + b.Run(size.name, func(b *testing.B) { + b.ResetTimer() + + for i := 0; i < b.N; i++ { + layer, idx := FindLayerPosition(tree, searchIdx) + if layer == nil || idx == -1 { + b.Fatal("failed to find validator") + } + } + }) + } +} + +func BenchmarkGetChildren(b *testing.B) { + sizes := []struct { + name string + valCount int + fanout uint32 + }{ + {"500-Val-200-Fanout-Root", 500, 200}, + 
{"500-Val-200-Fanout-Mid", 500, 200}, + {"2000-Val-200-Fanout-Root", 2000, 200}, + {"2000-Val-200-Fanout-Mid", 2000, 200}, + } + + for _, size := range sizes { + indices := make([]uint64, size.valCount) + for i := range indices { + indices[i] = uint64(i) + } + + builder := NewTreeBuilder(size.fanout) + tree, _ := builder.BuildTree(indices, 1, 0) + + // Test with root validator and middle layer validator + rootIdx := tree.Root.Validators[0] + midIdx := tree.Root.Children[0].Validators[0] + + b.Run(size.name+"-Root", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + children := GetChildren(tree, rootIdx) + if len(children) == 0 { + b.Fatal("expected children") + } + } + }) + + b.Run(size.name+"-Mid", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + children := GetChildren(tree, midIdx) + if children == nil { + b.Fatal("expected valid children slice") + } + } + }) + } +} diff --git a/gturbine/gtbuilder/helpers.go b/gturbine/gtbuilder/helpers.go new file mode 100644 index 0000000..a5a7bb3 --- /dev/null +++ b/gturbine/gtbuilder/helpers.go @@ -0,0 +1,48 @@ +package gtbuilder + +import ( + "github.com/gordian-engine/gordian/gturbine" +) + +// FindLayerPosition finds a validator's layer and index in the tree +func FindLayerPosition(tree *gturbine.Tree, valIndex uint64) (*gturbine.Layer, int) { + if tree == nil { + return nil, -1 + } + + layer := tree.Root + for layer != nil { + for i, idx := range layer.Validators { + if idx == valIndex { + return layer, i + } + } + if len(layer.Children) > 0 { + layer = layer.Children[0] + } else { + layer = nil + } + } + return nil, -1 +} + +// GetChildren returns the validator indices that should receive forwarded shreds +func GetChildren(tree *gturbine.Tree, valIndex uint64) []uint64 { + layer, idx := FindLayerPosition(tree, valIndex) + if layer == nil || len(layer.Children) == 0 { + return nil + } + + // Same distribution logic, now returning indices + childLayer := layer.Children[0] + startIdx := (idx * len(childLayer.Validators)) / len(layer.Validators) + endIdx := ((idx + 1) * len(childLayer.Validators)) / len(layer.Validators) + if startIdx == endIdx { + endIdx = startIdx + 1 + } + if endIdx > len(childLayer.Validators) { + endIdx = len(childLayer.Validators) + } + + return childLayer.Validators[startIdx:endIdx] +} diff --git a/gturbine/gtbuilder/helpers_test.go b/gturbine/gtbuilder/helpers_test.go new file mode 100644 index 0000000..789f7ed --- /dev/null +++ b/gturbine/gtbuilder/helpers_test.go @@ -0,0 +1,77 @@ +package gtbuilder + +import ( + "testing" + + "github.com/gordian-engine/gordian/gturbine" +) + +func makeTestTree(count int) *gturbine.Tree { + b := NewTreeBuilder(200) + indices := make([]uint64, count) + for i := 0; i < count; i++ { + indices[i] = uint64(i) + } + tree, _ := b.BuildTree(indices, 1, 0) + return tree +} + +func TestHelpers(t *testing.T) { + t.Run("find position in large tree", func(t *testing.T) { + tree := makeTestTree(500) + + // Test finding first validator + layer, idx := FindLayerPosition(tree, 0) + if layer == nil { + t.Fatal("validator not found") + } + if layer.Validators[idx] != 0 { + t.Errorf("found wrong validator index: want 0, got %d", layer.Validators[idx]) + } + + // Test finding last validator + layer, idx = FindLayerPosition(tree, 499) + if layer == nil { + t.Fatal("validator not found") + } + if layer.Validators[idx] != 499 { + t.Errorf("found wrong validator index: want 499, got %d", layer.Validators[idx]) + } + + // Test non-existent validator + layer, idx = 
FindLayerPosition(tree, 1000) + if layer != nil || idx != -1 { + t.Error("found non-existent validator") + } + }) + + t.Run("children distribution in large tree", func(t *testing.T) { + tree := makeTestTree(500) + + // Root layer validators should each get 1 child + for _, v := range tree.Root.Validators { + children := GetChildren(tree, v) + expectedCount := 1 // 200 validators divided by 200 fanout + if len(children) != expectedCount { + t.Errorf("validator %d: expected %d children, got %d", v, expectedCount, len(children)) + } + } + + // Middle layer validators should each get 0 or 1 child + for _, v := range tree.Root.Children[0].Validators { + children := GetChildren(tree, v) + if len(children) > 1 { + t.Errorf("middle layer validator %d has too many children: %d", v, len(children)) + } + } + + // Last layer validators should have no children + lastLayer := tree.Root.Children[0].Children[0] + for _, v := range lastLayer.Validators { + children := GetChildren(tree, v) + if len(children) != 0 { + t.Errorf("last layer validator %d has children when it shouldn't: %d", v, len(children)) + } + } + }) +} diff --git a/gturbine/gtbuilder/tree.go b/gturbine/gtbuilder/tree.go new file mode 100644 index 0000000..2649dd3 --- /dev/null +++ b/gturbine/gtbuilder/tree.go @@ -0,0 +1,96 @@ +package gtbuilder + +import ( + "crypto/sha256" + "encoding/binary" + + "github.com/gordian-engine/gordian/gturbine" +) + +type TreeBuilder struct { + fanout uint32 +} + +func NewTreeBuilder(fanout uint32) *TreeBuilder { + return &TreeBuilder{ + fanout: fanout, + } +} + +// BuildTree creates a new propagation tree with stake-weighted validator ordering +// It takes valIndices which is a presorted array of indices. It returns the tree +// with the indices as the values. It is up to the caller to map these indicies to +// the actual validators +func (b *TreeBuilder) BuildTree(valIndices []uint64, slot uint64, shredIndex uint32) (*gturbine.Tree, error) { + if len(valIndices) == 0 { + return nil, nil + } + + // Generate deterministic seed for shuffling + seed := b.deriveTreeSeed(slot, shredIndex, 0) + + // Fisher-Yates shuffle with deterministic seed + for i := len(valIndices) - 1; i > 0; i-- { + // Use seed to generate index + j := int(binary.LittleEndian.Uint64(seed) % uint64(i+1)) + valIndices[i], valIndices[j] = valIndices[j], valIndices[i] + + // Update seed for next iteration + h := sha256.New() + h.Write(seed) + seed = h.Sum(nil) + } + + // Build layers + tree := >urbine.Tree{ + Fanout: b.fanout, + } + + remaining := valIndices + currentLayer := >urbine.Layer{} + tree.Root = currentLayer + tree.Height = 1 + + for len(remaining) > 0 { + // Take up to fanout validators for current layer + takeCount := min(len(remaining), int(b.fanout)) + currentLayer.Validators = remaining[:takeCount] + remaining = remaining[takeCount:] + + if len(remaining) > 0 { + // Create new layer + newLayer := >urbine.Layer{ + Parent: currentLayer, + } + currentLayer.Children = append(currentLayer.Children, newLayer) + currentLayer = newLayer + tree.Height++ + } + } + + return tree, nil +} + +// deriveTreeSeed generates deterministic seed for tree creation +func (b *TreeBuilder) deriveTreeSeed(slot uint64, shredIndex uint32, shredType uint8) []byte { + h := sha256.New() + + slotBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(slotBytes, slot) + h.Write(slotBytes) + + shredBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(shredBytes, shredIndex) + h.Write(shredBytes) + + h.Write([]byte{shredType}) + + return h.Sum(nil) +} + +func 
min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/gturbine/gtbuilder/tree_test.go b/gturbine/gtbuilder/tree_test.go new file mode 100644 index 0000000..e763a4a --- /dev/null +++ b/gturbine/gtbuilder/tree_test.go @@ -0,0 +1,88 @@ +package gtbuilder + +import ( + "testing" +) + +func TestTreeBuilder(t *testing.T) { + makeIndices := func(count int) []uint64 { + indices := make([]uint64, count) + for i := 0; i < count; i++ { + indices[i] = uint64(i) + } + return indices + } + + t.Run("production size tree - 500 validators", func(t *testing.T) { + b := NewTreeBuilder(200) + indices := makeIndices(500) + tree, err := b.BuildTree(indices, 1, 0) + if err != nil { + t.Fatal(err) + } + + if tree.Height != 3 { + t.Errorf("expected height 3 for 500 validators, got %d", tree.Height) + } + + if len(tree.Root.Validators) != 200 { + t.Errorf("expected 200 validators in root, got %d", len(tree.Root.Validators)) + } + + if len(tree.Root.Children[0].Validators) != 200 { + t.Errorf("expected 200 validators in second layer, got %d", len(tree.Root.Children[0].Validators)) + } + + lastLayer := tree.Root.Children[0].Children[0] + if len(lastLayer.Validators) != 100 { + t.Errorf("expected 100 validators in last layer, got %d", len(lastLayer.Validators)) + } + }) + + t.Run("determinism", func(t *testing.T) { + b := NewTreeBuilder(200) + + tree1, _ := b.BuildTree(makeIndices(500), 1, 0) + tree2, _ := b.BuildTree(makeIndices(500), 1, 0) + tree3, _ := b.BuildTree(makeIndices(500), 2, 0) + + if !compareValidators(tree1.Root.Validators, tree2.Root.Validators) { + t.Error("trees not deterministic for same inputs") + } + + if compareValidators(tree1.Root.Validators, tree3.Root.Validators) { + t.Error("expected different trees for different slots") + } + }) + + t.Run("children distribution", func(t *testing.T) { + b := NewTreeBuilder(200) + indices := makeIndices(500) + tree, _ := b.BuildTree(indices, 1, 0) + + children := GetChildren(tree, tree.Root.Validators[0]) + expectedCount := 1 + if len(children) != expectedCount { + t.Errorf("expected %d children for root validator, got %d", expectedCount, len(children)) + } + + for _, v := range tree.Root.Children[0].Validators { + children := GetChildren(tree, v) + if len(children) > 1 { + t.Error("middle layer validator has too many children") + } + } + }) +} + +func compareValidators(a, b []uint64) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/gturbine/gtencoding/benchmark_test.go b/gturbine/gtencoding/benchmark_test.go new file mode 100644 index 0000000..0c671ef --- /dev/null +++ b/gturbine/gtencoding/benchmark_test.go @@ -0,0 +1,323 @@ +package gtencoding + +import ( + "crypto/rand" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/gordian-engine/gordian/gturbine" +) + +// Benchmark sizes - all multiples of 64 bytes, ranging from 8MB to 128MB +type benchSize struct { + size int + name string +} + +var benchSizes = []benchSize{ + {name: "128MB", size: 128 * 1024 * 1024}, + {name: "64MB", size: 64 * 1024 * 1024}, + {name: "32MB", size: 32 * 1024 * 1024}, + {name: "16MB", size: 16 * 1024 * 1024}, + {name: "8MB", size: 8 * 1024 * 1024}, +} + +// BenchmarkEncode tests binary encoding performance at various block sizes +func BenchmarkEncode(b *testing.B) { + for _, size := range benchSizes { + b.Run(fmt.Sprintf("%s", size.name), func(b *testing.B) { + data := make([]byte, size.size) + rand.Read(data) + + shred := >urbine.Shred{ + Type: 
gturbine.DataShred, + FullDataSize: size.size, + BlockHash: make([]byte, 32), + GroupID: uuid.New().String(), + Height: 1, + Index: 0, + TotalDataShreds: 16, + TotalRecoveryShreds: 4, + Data: data, + } + + codec := NewBinaryShardCodec() + + b.ResetTimer() + b.SetBytes(int64(size.size)) + + for i := 0; i < b.N; i++ { + _, err := codec.Encode(shred) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +// BenchmarkDecode tests binary decoding performance at various block sizes +func BenchmarkDecode(b *testing.B) { + for _, size := range benchSizes { + b.Run(fmt.Sprintf("%s", size.name), func(b *testing.B) { + data := make([]byte, size.size) + rand.Read(data) + + shred := >urbine.Shred{ + Type: gturbine.DataShred, + FullDataSize: size.size, + BlockHash: make([]byte, 32), + GroupID: uuid.New().String(), + Height: 1, + Index: 0, + TotalDataShreds: 16, + TotalRecoveryShreds: 4, + Data: data, + } + + codec := NewBinaryShardCodec() + encoded, _ := codec.Encode(shred) + + b.ResetTimer() + b.SetBytes(int64(size.size)) + + for i := 0; i < b.N; i++ { + _, err := codec.Decode(encoded) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +// BenchmarkErasureEncoding tests Reed-Solomon encoding with different configurations +func BenchmarkErasureEncoding(b *testing.B) { + configs := []struct { + data int + recovery int + }{ + {4, 2}, // 33% overhead + {8, 4}, // 33% overhead + {16, 4}, // 20% overhead + {32, 8}, // 20% overhead + } + + for _, size := range benchSizes { + for _, cfg := range configs { + shardSize := size.size / cfg.data + name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) + + b.Run(name, func(b *testing.B) { + enc, err := NewEncoder(cfg.data, cfg.recovery) + if err != nil { + b.Fatal(err) + } + + shreds := make([][]byte, cfg.data) + for i := range shreds { + shreds[i] = make([]byte, shardSize) + rand.Read(shreds[i]) + } + + b.ResetTimer() + b.SetBytes(int64(size.size)) + + for i := 0; i < b.N; i++ { + _, err := enc.GenerateRecoveryShreds(shreds) + if err != nil { + b.Fatal(err) + } + } + }) + } + } +} + +// BenchmarkErasureReconstruction tests Reed-Solomon reconstruction with different configurations +func BenchmarkErasureReconstruction(b *testing.B) { + configs := []struct { + data int + recovery int + }{ + {4, 2}, // 33% overhead + {8, 4}, // 33% overhead + {16, 4}, // 20% overhead + {32, 8}, // 20% overhead + } + + for _, size := range benchSizes { + for _, cfg := range configs { + shardSize := size.size / cfg.data + name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) + + b.Run(name, func(b *testing.B) { + enc, err := NewEncoder(cfg.data, cfg.recovery) + if err != nil { + b.Fatal(err) + } + + // Generate test data + shreds := make([][]byte, cfg.data) + for i := range shreds { + shreds[i] = make([]byte, shardSize) + rand.Read(shreds[i]) + } + + // Generate recovery shreds + recoveryShreds, err := enc.GenerateRecoveryShreds(shreds) + if err != nil { + b.Fatal(err) + } + + // Combine all shreds + allShreds := append(shreds, recoveryShreds...) 
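				// Reed-Solomon can rebuild the block from any cfg.data of the
				// cfg.data+cfg.recovery shards, so nilling out cfg.recovery of
				// them below is the largest loss that is still recoverable.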
+ + // Simulate worst case - lose maximum recoverable shards + for i := 0; i < cfg.recovery; i++ { + allShreds[i] = nil + } + + b.ResetTimer() + b.SetBytes(int64(size.size)) + + for i := 0; i < b.N; i++ { + // Make a copy since Reconstruct modifies the slice + testShreds := make([][]byte, len(allShreds)) + copy(testShreds, allShreds) + + err := enc.Reconstruct(testShreds) + if err != nil { + b.Fatal(err) + } + } + }) + } + } +} + +// BenchmarkErasureVerification tests Reed-Solomon verification with different configurations +func BenchmarkErasureVerification(b *testing.B) { + configs := []struct { + data int + recovery int + }{ + {4, 2}, // 33% overhead + {8, 4}, // 33% overhead + {16, 4}, // 20% overhead + {32, 8}, // 20% overhead + } + + for _, size := range benchSizes { + for _, cfg := range configs { + shardSize := size.size / cfg.data + name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) + + b.Run(name, func(b *testing.B) { + enc, err := NewEncoder(cfg.data, cfg.recovery) + if err != nil { + b.Fatal(err) + } + + // Generate test data + shreds := make([][]byte, cfg.data) + for i := range shreds { + shreds[i] = make([]byte, shardSize) + rand.Read(shreds[i]) + } + + // Generate recovery shreds + recoveryShreds, err := enc.GenerateRecoveryShreds(shreds) + if err != nil { + b.Fatal(err) + } + + // Combine all shreds + allShreds := append(shreds, recoveryShreds...) + + b.ResetTimer() + b.SetBytes(int64(size.size)) + + for i := 0; i < b.N; i++ { + ok, err := enc.Verify(allShreds) + if err != nil { + b.Fatal(err) + } + if !ok { + b.Fatal("verification failed") + } + } + }) + } + } +} + +// BenchmarkFullPipeline tests the complete encoding process +func BenchmarkFullPipeline(b *testing.B) { + configs := []struct { + data int + recovery int + }{ + {16, 4}, // Typical configuration + } + + for _, size := range benchSizes { + for _, cfg := range configs { + name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) + + b.Run(name, func(b *testing.B) { + binaryCodec := NewBinaryShardCodec() + erasureEnc, err := NewEncoder(cfg.data, cfg.recovery) + if err != nil { + b.Fatal(err) + } + + data := make([]byte, size.size) + rand.Read(data) + + b.ResetTimer() + b.SetBytes(int64(size.size)) + + for i := 0; i < b.N; i++ { + shreds := make([][]byte, cfg.data) + shredSize := size.size / cfg.data + + for j := 0; j < cfg.data; j++ { + shred := >urbine.Shred{ + Type: gturbine.DataShred, + FullDataSize: size.size, + BlockHash: make([]byte, 32), + GroupID: uuid.New().String(), + Height: 1, + Index: j, + TotalDataShreds: cfg.data, + TotalRecoveryShreds: cfg.recovery, + Data: data[j*shredSize : (j+1)*shredSize], + } + + encoded, err := binaryCodec.Encode(shred) + if err != nil { + b.Fatal(err) + } + shreds[j] = encoded + } + + recoveryShreds, err := erasureEnc.GenerateRecoveryShreds(shreds) + if err != nil { + b.Fatal(err) + } + + allShreds := append(shreds, recoveryShreds...) 
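					// Verify re-encodes parity from the data shards and checks that
					// it matches the recovery shards; nothing is reconstructed.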
+ ok, err := erasureEnc.Verify(allShreds) + if err != nil { + b.Fatal(err) + } + if !ok { + b.Fatal("verification failed") + } + } + }) + } + } +} diff --git a/gturbine/gtencoding/binary_codec_bench_test.go b/gturbine/gtencoding/binary_codec_bench_test.go new file mode 100644 index 0000000..e84a0ee --- /dev/null +++ b/gturbine/gtencoding/binary_codec_bench_test.go @@ -0,0 +1,165 @@ +package gtencoding + +import ( + "bytes" + "crypto/rand" + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/gordian-engine/gordian/gturbine" +) + +// TestShred represents a reusable test shred configuration +type TestShred struct { + size int + dataType gturbine.ShredType +} + +var testConfigs = []TestShred{ + {64, gturbine.DataShred}, // Minimum size + {1024, gturbine.DataShred}, // 1KB + {64 * 1024, gturbine.DataShred}, // 64KB + {1024 * 1024, gturbine.DataShred}, // 1MB +} + +// BenchmarkBinaryCodec runs comprehensive benchmarks for the binary codec +func BenchmarkBinaryCodec(b *testing.B) { + for _, cfg := range testConfigs { + b.Run(benchName("Encode", cfg), func(b *testing.B) { + benchmarkEncode(b, cfg) + }) + b.Run(benchName("Decode", cfg), func(b *testing.B) { + benchmarkDecode(b, cfg) + }) + b.Run(benchName("RoundTrip", cfg), func(b *testing.B) { + benchmarkRoundTrip(b, cfg) + }) + } +} + +// BenchmarkBinaryCodecParallel tests parallel encoding/decoding performance +func BenchmarkBinaryCodecParallel(b *testing.B) { + for _, cfg := range testConfigs { + b.Run(benchName("EncodeParallel", cfg), func(b *testing.B) { + benchmarkEncodeParallel(b, cfg) + }) + b.Run(benchName("DecodeParallel", cfg), func(b *testing.B) { + benchmarkDecodeParallel(b, cfg) + }) + } +} + +// Helper to create consistent benchmark names +func benchName(op string, cfg TestShred) string { + return fmt.Sprintf("%s/%dB", op, cfg.size) +} + +// Helper to create a test shred +func createTestShred(cfg TestShred) *gturbine.Shred { + data := make([]byte, cfg.size) + rand.Read(data) + + return >urbine.Shred{ + Type: cfg.dataType, + FullDataSize: cfg.size, + BlockHash: bytes.Repeat([]byte{0xFF}, blockHashSize), // Fixed pattern for consistent benchmarking + GroupID: uuid.New().String(), + Height: 1, + Index: 0, + TotalDataShreds: 16, + TotalRecoveryShreds: 4, + Data: data, + } +} + +func benchmarkEncode(b *testing.B, cfg TestShred) { + codec := NewBinaryShardCodec() + shred := createTestShred(cfg) + + b.ResetTimer() + b.SetBytes(int64(cfg.size + prefixSize)) + + for i := 0; i < b.N; i++ { + _, err := codec.Encode(shred) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkDecode(b *testing.B, cfg TestShred) { + codec := NewBinaryShardCodec() + shred := createTestShred(cfg) + encoded, err := codec.Encode(shred) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.SetBytes(int64(cfg.size + prefixSize)) + + for i := 0; i < b.N; i++ { + _, err := codec.Decode(encoded) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkRoundTrip(b *testing.B, cfg TestShred) { + codec := NewBinaryShardCodec() + shred := createTestShred(cfg) + + b.ResetTimer() + b.SetBytes(int64(cfg.size + prefixSize)) + + for i := 0; i < b.N; i++ { + encoded, err := codec.Encode(shred) + if err != nil { + b.Fatal(err) + } + _, err = codec.Decode(encoded) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkEncodeParallel(b *testing.B, cfg TestShred) { + codec := NewBinaryShardCodec() + shred := createTestShred(cfg) + + b.ResetTimer() + b.SetBytes(int64(cfg.size + prefixSize)) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, 
err := codec.Encode(shred) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func benchmarkDecodeParallel(b *testing.B, cfg TestShred) { + codec := NewBinaryShardCodec() + shred := createTestShred(cfg) + encoded, err := codec.Encode(shred) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.SetBytes(int64(cfg.size + prefixSize)) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := codec.Decode(encoded) + if err != nil { + b.Fatal(err) + } + } + }) +} diff --git a/gturbine/gtencoding/binary_encoder.go b/gturbine/gtencoding/binary_encoder.go new file mode 100644 index 0000000..0c4d326 --- /dev/null +++ b/gturbine/gtencoding/binary_encoder.go @@ -0,0 +1,120 @@ +package gtencoding + +import ( + "encoding/binary" + "fmt" + + "github.com/google/uuid" + "github.com/gordian-engine/gordian/gturbine" +) + +const ( + int16Size = 2 + int32Size = 4 + int64Size = 8 + versionSize = int16Size + typeSize = int16Size + uuidSize = 16 + + fullDataSizeSize = int64Size + blockHashSize = 32 + groupIDSize = uuidSize + heightSize = int64Size + indexSize = int64Size + totalDataShredsSize = int64Size + totalRecoveryShredsSize = int64Size + + prefixSize = versionSize + typeSize + fullDataSizeSize + blockHashSize + groupIDSize + heightSize + indexSize + totalDataShredsSize + totalRecoveryShredsSize + binaryVersion = 1 +) + +// BinaryShardCodec represents a codec for encoding and decoding shreds +type BinaryShardCodec struct{} + +func NewBinaryShardCodec() *BinaryShardCodec { + return &BinaryShardCodec{} +} + +func (bsc *BinaryShardCodec) Encode(shred *gturbine.Shred) ([]byte, error) { + out := make([]byte, prefixSize+len(shred.Data)) + + // Write version + binary.LittleEndian.PutUint16(out[:2], binaryVersion) + + // Write type + binary.LittleEndian.PutUint16(out[2:4], uint16(shred.Type)) + + // Write full data size + binary.LittleEndian.PutUint64(out[4:12], uint64(shred.FullDataSize)) + + // Write block hash + copy(out[12:44], shred.BlockHash) + + uid, err := uuid.Parse(shred.GroupID) + if err != nil { + return nil, fmt.Errorf("failed to parse group ID: %w", err) + } + // Write group ID + copy(out[44:60], uid[:]) + + // Write height + binary.LittleEndian.PutUint64(out[60:68], shred.Height) + + // Write index + binary.LittleEndian.PutUint64(out[68:76], uint64(shred.Index)) + + // Write total data shreds + binary.LittleEndian.PutUint64(out[76:84], uint64(shred.TotalDataShreds)) + + // Write total recovery shreds + binary.LittleEndian.PutUint64(out[84:92], uint64(shred.TotalRecoveryShreds)) + + // Write data + copy(out[prefixSize:], shred.Data) + + return out, nil + +} + +func (bsc *BinaryShardCodec) Decode(data []byte) (*gturbine.Shred, error) { + shred := gturbine.Shred{} + + // Read version + version := binary.LittleEndian.Uint16(data[:2]) + if version != binaryVersion { + return nil, fmt.Errorf("unsupported version: %d", version) + } + + // Read type + shred.Type = gturbine.ShredType(binary.LittleEndian.Uint16(data[2:4])) + + // Read full data size + shred.FullDataSize = int(binary.LittleEndian.Uint64(data[4:12])) + + // Read block hash + shred.BlockHash = make([]byte, blockHashSize) + copy(shred.BlockHash, data[12:44]) + + // Read group ID + uid := uuid.UUID{} + copy(uid[:], data[44:60]) + shred.GroupID = uid.String() + + // Read height + shred.Height = binary.LittleEndian.Uint64(data[60:68]) + + // Read index + shred.Index = int(binary.LittleEndian.Uint64(data[68:76])) + + // Read total data shreds + shred.TotalDataShreds = int(binary.LittleEndian.Uint64(data[76:84])) + + // Read total 
recovery shreds + shred.TotalRecoveryShreds = int(binary.LittleEndian.Uint64(data[84:92])) + + // Read data + shred.Data = make([]byte, len(data)-prefixSize) + copy(shred.Data, data[prefixSize:]) + + return &shred, nil +} diff --git a/gturbine/gtencoding/binary_encoder_test.go b/gturbine/gtencoding/binary_encoder_test.go new file mode 100644 index 0000000..2463087 --- /dev/null +++ b/gturbine/gtencoding/binary_encoder_test.go @@ -0,0 +1,153 @@ +package gtencoding + +import ( + "bytes" + "testing" + + "github.com/google/uuid" + "github.com/gordian-engine/gordian/gturbine" +) + +func TestBinaryShardCodec_EncodeDecode(t *testing.T) { + tests := []struct { + name string + shred *gturbine.Shred + wantErr bool + }{ + { + name: "basic encode/decode", + shred: >urbine.Shred{ + FullDataSize: 1000, + BlockHash: bytes.Repeat([]byte{1}, 32), + GroupID: uuid.New().String(), + Height: 12345, + Index: 5, + TotalDataShreds: 10, + TotalRecoveryShreds: 2, + Data: []byte("test data"), + }, + wantErr: false, + }, + { + name: "empty data", + shred: >urbine.Shred{ + FullDataSize: 0, + BlockHash: bytes.Repeat([]byte{2}, 32), + GroupID: uuid.New().String(), + Height: 67890, + Index: 0, + TotalDataShreds: 1, + TotalRecoveryShreds: 0, + Data: []byte{}, + }, + wantErr: false, + }, + { + name: "large data", + shred: >urbine.Shred{ + FullDataSize: 1000000, + BlockHash: bytes.Repeat([]byte{3}, 32), + GroupID: uuid.New().String(), + Height: 999999, + Index: 50, + TotalDataShreds: 100, + TotalRecoveryShreds: 20, + Data: bytes.Repeat([]byte("large data"), 1000), + }, + wantErr: false, + }, + } + + codec := &BinaryShardCodec{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test encoding + encoded, err := codec.Encode(tt.shred) + if (err != nil) != tt.wantErr { + t.Errorf("BinaryShardCodec.Encode() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + // Test decoding + decoded, err := codec.Decode(encoded) + if err != nil { + t.Errorf("BinaryShardCodec.Decode() error = %v", err) + return + } + + // Verify all fields match + if decoded.FullDataSize != tt.shred.FullDataSize { + t.Errorf("FullDataSize mismatch: got %v, want %v", decoded.FullDataSize, tt.shred.FullDataSize) + } + + if !bytes.Equal(decoded.BlockHash, tt.shred.BlockHash) { + t.Errorf("BlockHash mismatch: got %v, want %v", decoded.BlockHash, tt.shred.BlockHash) + } + + if decoded.GroupID != tt.shred.GroupID { + t.Errorf("GroupID mismatch: got %v, want %v", decoded.GroupID, tt.shred.GroupID) + } + + if decoded.Height != tt.shred.Height { + t.Errorf("Height mismatch: got %v, want %v", decoded.Height, tt.shred.Height) + } + + if decoded.Index != tt.shred.Index { + t.Errorf("Index mismatch: got %v, want %v", decoded.Index, tt.shred.Index) + } + + if decoded.TotalDataShreds != tt.shred.TotalDataShreds { + t.Errorf("TotalDataShreds mismatch: got %v, want %v", decoded.TotalDataShreds, tt.shred.TotalDataShreds) + } + + if decoded.TotalRecoveryShreds != tt.shred.TotalRecoveryShreds { + t.Errorf("TotalRecoveryShreds mismatch: got %v, want %v", decoded.TotalRecoveryShreds, tt.shred.TotalRecoveryShreds) + } + + if !bytes.Equal(decoded.Data, tt.shred.Data) { + t.Errorf("Data mismatch: got %v, want %v", decoded.Data, tt.shred.Data) + } + }) + } +} + +func TestBinaryShardCodec_InvalidGroupID(t *testing.T) { + codec := &BinaryShardCodec{} + shred := >urbine.Shred{ + GroupID: "invalid-uuid", + // Other fields can be empty for this test + } + + _, err := codec.Encode(shred) + if err == nil { + t.Error("Expected error 
when encoding invalid GroupID, got nil") + } +} + +func TestBinaryShardCodec_DataSizes(t *testing.T) { + codec := &BinaryShardCodec{} + shred := >urbine.Shred{ + FullDataSize: 1000, + BlockHash: bytes.Repeat([]byte{1}, 32), + GroupID: uuid.New().String(), + Height: 12345, + Index: 5, + TotalDataShreds: 10, + TotalRecoveryShreds: 2, + Data: []byte("test data"), + } + + encoded, err := codec.Encode(shred) + if err != nil { + t.Fatalf("Failed to encode shred: %v", err) + } + + if len(encoded) != prefixSize+len(shred.Data) { + t.Errorf("Encoded data size mismatch: got %v, want %v", len(encoded), prefixSize+len(shred.Data)) + } +} diff --git a/gturbine/gtencoding/encoder.go b/gturbine/gtencoding/encoder.go new file mode 100644 index 0000000..d2c943b --- /dev/null +++ b/gturbine/gtencoding/encoder.go @@ -0,0 +1,8 @@ +package gtencoding + +import "github.com/gordian-engine/gordian/gturbine" + +type ShardCodec interface { + Encode(shred *gturbine.Shred) ([]byte, error) + Decode(data []byte) (*gturbine.Shred, error) +} diff --git a/gturbine/gtencoding/erasure.go b/gturbine/gtencoding/erasure.go new file mode 100644 index 0000000..d5032d6 --- /dev/null +++ b/gturbine/gtencoding/erasure.go @@ -0,0 +1,88 @@ +package gtencoding + +import ( + "fmt" + + "github.com/klauspost/reedsolomon" +) + +const maxTotalShreds = 128 + +type Encoder struct { + enc reedsolomon.Encoder + dataShreds int + recoveryShreds int +} + +func NewEncoder(dataShreds, recoveryShreds int) (*Encoder, error) { + if dataShreds <= 0 { + return nil, fmt.Errorf("data shreds must be > 0") + } + if recoveryShreds <= 0 { + return nil, fmt.Errorf("recovery shreds must be > 0") + } + if dataShreds+recoveryShreds > maxTotalShreds { + return nil, fmt.Errorf("total shreds must be <= %d", maxTotalShreds) + } + + enc, err := reedsolomon.New(dataShreds, recoveryShreds) + if err != nil { + return nil, fmt.Errorf("failed to create reed-solomon encoder: %w", err) + } + + return &Encoder{ + enc: enc, + dataShreds: dataShreds, + recoveryShreds: recoveryShreds, + }, nil +} + +func (e *Encoder) GenerateRecoveryShreds(shreds [][]byte) ([][]byte, error) { + if len(shreds) != e.dataShreds { + return nil, fmt.Errorf("expected %d data shreds, got %d", e.dataShreds, len(shreds)) + } + + totalShreds := make([][]byte, e.dataShreds+e.recoveryShreds) + copy(totalShreds, shreds) + + for i := e.dataShreds; i < len(totalShreds); i++ { + totalShreds[i] = make([]byte, len(shreds[0])) + } + + if err := e.enc.Encode(totalShreds); err != nil { + return nil, fmt.Errorf("encoding failed: %w", err) + } + + return totalShreds[e.dataShreds:], nil +} + +func (e *Encoder) Reconstruct(allShreds [][]byte) error { + if len(allShreds) != e.dataShreds+e.recoveryShreds { + return fmt.Errorf("expected %d total shreds, got %d", e.dataShreds+e.recoveryShreds, len(allShreds)) + } + + // Count non-nil shreds + validShreds := 0 + for _, shred := range allShreds { + if shred != nil { + validShreds++ + } + } + + // Need at least dataShreds valid pieces for reconstruction + if validShreds < e.dataShreds { + return fmt.Errorf("insufficient shreds for reconstruction: have %d, need %d", validShreds, e.dataShreds) + } + + if err := e.enc.Reconstruct(allShreds); err != nil { + return fmt.Errorf("reconstruction failed: %w", err) + } + return nil +} + +func (e *Encoder) Verify(allShreds [][]byte) (bool, error) { + if len(allShreds) != e.dataShreds+e.recoveryShreds { + return false, fmt.Errorf("expected %d total shreds, got %d", e.dataShreds+e.recoveryShreds, len(allShreds)) + } + return 
e.enc.Verify(allShreds) +} diff --git a/gturbine/gtencoding/erasure_test.go b/gturbine/gtencoding/erasure_test.go new file mode 100644 index 0000000..27e2590 --- /dev/null +++ b/gturbine/gtencoding/erasure_test.go @@ -0,0 +1,189 @@ +package gtencoding + +import ( + "bytes" + "crypto/rand" + "fmt" + mrand "math/rand" + "testing" +) + +const ( + // NOTE: blocksize needs to be a multiple of 64 for reed solomon to work + blockSize = 128 * 1024 * 1024 // 128MB blocks same as solana + numTests = 5 // Number of iterations for randomized tests +) + +func TestEncoderRealWorld(t *testing.T) { + t.Run("solana-like configuration", func(t *testing.T) { + enc, err := NewEncoder(32, 32) + if err != nil { + t.Fatal(err) + } + + shredSize := blockSize / 32 + dataShreds := make([][]byte, 32) + for i := range dataShreds { + dataShreds[i] = make([]byte, shredSize) + if _, err := rand.Read(dataShreds[i]); err != nil { + t.Fatal(err) + } + } + + recoveryShreds, err := enc.GenerateRecoveryShreds(dataShreds) + if err != nil { + t.Fatal(err) + } + + if len(recoveryShreds) != 32 { + t.Fatalf("expected 32 recovery shreds, got %d", len(recoveryShreds)) + } + + allShreds := append(dataShreds, recoveryShreds...) + + scenarios := []struct { + name string + numDataLost int + numParityLost int + shouldRecover bool + }{ + {"lose 16 data shreds", 16, 0, true}, + {"lose 16 data and 16 parity shreds", 16, 16, true}, + {"lose 31 data shreds", 31, 0, true}, + {"lose all parity shreds", 0, 32, true}, + {"lose 32 data shreds and 1 parity shred", 32, 1, false}, + {"lose 31 data and 2 parity shreds", 31, 2, false}, + } + + for _, sc := range scenarios { + t.Run(sc.name, func(t *testing.T) { + testShreds := make([][]byte, len(allShreds)) + copy(testShreds, allShreds) + + // Remove data shreds + for i := 0; i < sc.numDataLost; i++ { + testShreds[i] = nil + } + + // Remove parity shreds + for i := 0; i < sc.numParityLost; i++ { + testShreds[32+i] = nil + } + + err := enc.Reconstruct(testShreds) + if sc.shouldRecover { + if err != nil { + t.Errorf("failed to reconstruct when it should: %v", err) + return + } + // Verify reconstruction + for i := range dataShreds { + if !bytes.Equal(testShreds[i], dataShreds[i]) { + t.Errorf("shard %d not properly reconstructed", i) + } + } + } else if err == nil { + t.Error("reconstruction succeeded when it should have failed") + } + }) + } + }) + + t.Run("random failure patterns", func(t *testing.T) { + enc, _ := NewEncoder(32, 32) + shredSize := blockSize / 32 + + for i := 0; i < numTests; i++ { + dataShreds := make([][]byte, 32) + for j := range dataShreds { + dataShreds[j] = make([]byte, shredSize) + if _, err := rand.Read(dataShreds[j]); err != nil { + t.Fatal(err) + } + } + + recoveryShreds, err := enc.GenerateRecoveryShreds(dataShreds) + if err != nil { + t.Fatal(err) + } + + allShreds := append(dataShreds, recoveryShreds...) 
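			// Knock shards out of a copy: Reconstruct fills the missing
			// entries of the slice it is handed in place.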
+ testShreds := make([][]byte, len(allShreds)) + copy(testShreds, allShreds) + + numToRemove := mrand.Intn(32) + removedIndices := make(map[int]bool) + for j := 0; j < numToRemove; j++ { + for { + idx := mrand.Intn(len(testShreds)) + if !removedIndices[idx] { + testShreds[idx] = nil + removedIndices[idx] = true + break + } + } + } + + err = enc.Reconstruct(testShreds) + if numToRemove >= 32 { + if err == nil { + t.Errorf("test %d: reconstruction succeeded with %d shreds removed", i, numToRemove) + } + } else { + if err != nil { + t.Errorf("test %d: failed to reconstruct with %d shreds removed: %v", i, numToRemove, err) + continue + } + + for j := range dataShreds { + if !bytes.Equal(testShreds[j], dataShreds[j]) { + t.Errorf("test %d: shard %d not properly reconstructed", i, j) + } + } + } + } + }) + + t.Run("performance benchmarks", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping performance test in short mode") + } + + enc, _ := NewEncoder(32, 32) + shredSize := blockSize / 32 + dataShreds := make([][]byte, 32) + for i := range dataShreds { + dataShreds[i] = make([]byte, shredSize) + rand.Read(dataShreds[i]) + } + + recoveryShreds, err := enc.GenerateRecoveryShreds(dataShreds) + if err != nil { + t.Fatal(err) + } + + allShreds := append(dataShreds, recoveryShreds...) + lostCounts := []int{8, 16, 24, 31} + + for _, count := range lostCounts { + t.Run(fmt.Sprintf("reconstruct_%d_lost", count), func(t *testing.T) { + testShreds := make([][]byte, len(allShreds)) + copy(testShreds, allShreds) + + for i := 0; i < count; i++ { + testShreds[i] = nil + } + + if err := enc.Reconstruct(testShreds); err != nil { + t.Fatal(err) + } + + for i := range dataShreds { + if !bytes.Equal(testShreds[i], dataShreds[i]) { + t.Errorf("shard %d not properly reconstructed", i) + } + } + }) + } + }) +} diff --git a/gturbine/gtnetwork/transport.go b/gturbine/gtnetwork/transport.go new file mode 100644 index 0000000..62ad834 --- /dev/null +++ b/gturbine/gtnetwork/transport.go @@ -0,0 +1,138 @@ +package gtnetwork + +import ( + "context" + "fmt" + "net" + "sync" +) + +// Transport handles shred sending/receiving over UDP +type Transport struct { + basePort int + numPorts int + listeners []*net.UDPConn + handlers []ShredHandler + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex +} + +// ShredHandler processes received shreds +type ShredHandler interface { + HandleShred(data []byte, from net.Addr) error +} + +// Config contains Transport configuration +type Config struct { + BasePort int + NumPorts int +} + +// DefaultConfig returns standard Transport configuration +func DefaultConfig() Config { + return Config{ + BasePort: 12000, + NumPorts: 10, + } +} + +// NewTransport creates a new Transport instance +func NewTransport(cfg Config) *Transport { + ctx, cancel := context.WithCancel(context.Background()) + return &Transport{ + basePort: cfg.BasePort, + numPorts: cfg.NumPorts, + listeners: make([]*net.UDPConn, 0, cfg.NumPorts), + ctx: ctx, + cancel: cancel, + } +} + +// Start initializes all UDP listeners +func (t *Transport) Start() error { + t.mu.Lock() + defer t.mu.Unlock() + + for i := 0; i < t.numPorts; i++ { + addr := &net.UDPAddr{Port: t.basePort + i} + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.close() + return fmt.Errorf("failed to start UDP listener on port %d: %w", addr.Port, err) + } + t.listeners = append(t.listeners, conn) + go t.listen(conn) + } + return nil +} + +// Stop gracefully shuts down the transport +func (t *Transport) Stop() { + t.cancel() + 
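	// Closing the sockets below unblocks any pending ReadFromUDP calls so the
	// listen goroutines can observe the cancelled context and return.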
t.close() +} + +// BasePort returns the base port number for this transport +func (t *Transport) BasePort() int { + return t.basePort +} + +// AddHandler registers a new shred handler +func (t *Transport) AddHandler(h ShredHandler) { + t.mu.Lock() + defer t.mu.Unlock() + t.handlers = append(t.handlers, h) +} + +// SendShred sends data to the specified address +func (t *Transport) SendShred(data []byte, to *net.UDPAddr) error { + t.mu.RLock() + defer t.mu.RUnlock() + + // Try each listener until send succeeds + var lastErr error + for _, conn := range t.listeners { + _, err := conn.WriteToUDP(data, to) + if err == nil { + return nil + } + lastErr = err + } + return fmt.Errorf("failed to send shred: %w", lastErr) +} + +func (t *Transport) listen(conn *net.UDPConn) { + buf := make([]byte, 65507) // Max UDP packet size + + for { + select { + case <-t.ctx.Done(): + return + default: + n, addr, err := conn.ReadFromUDP(buf) + if err != nil { + continue + } + + data := make([]byte, n) + copy(data, buf[:n]) + + t.mu.RLock() + for _, h := range t.handlers { + go h.HandleShred(data, addr) + } + t.mu.RUnlock() + } + } +} + +func (t *Transport) close() { + t.mu.Lock() + defer t.mu.Unlock() + + for _, l := range t.listeners { + l.Close() + } + t.listeners = nil +} diff --git a/gturbine/gtnetwork/transport_test.go b/gturbine/gtnetwork/transport_test.go new file mode 100644 index 0000000..df2f8bb --- /dev/null +++ b/gturbine/gtnetwork/transport_test.go @@ -0,0 +1,105 @@ +package gtnetwork + +import ( + "net" + "sync" + "testing" + "time" +) + +type testHandler struct { + mu sync.Mutex + shreds [][]byte + addrs []net.Addr +} + +func (h *testHandler) HandleShred(data []byte, from net.Addr) error { + h.mu.Lock() + defer h.mu.Unlock() + h.shreds = append(h.shreds, data) + h.addrs = append(h.addrs, from) + return nil +} + +func TestTransport(t *testing.T) { + cfg := Config{ + BasePort: 30000, // Use high ports for testing + NumPorts: 2, + } + + tr := NewTransport(cfg) + handler := &testHandler{} + tr.AddHandler(handler) + + if err := tr.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + defer tr.Stop() + + // Allow listeners to start + time.Sleep(100 * time.Millisecond) + + // Test send/receive + testData := []byte("test data") + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: cfg.BasePort, + } + + if err := tr.SendShred(testData, addr); err != nil { + t.Fatalf("SendShred failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + handler.mu.Lock() + if len(handler.shreds) != 1 { + t.Fatalf("Expected 1 shred, got %d", len(handler.shreds)) + } + if string(handler.shreds[0]) != string(testData) { + t.Errorf("Expected %q, got %q", testData, handler.shreds[0]) + } + handler.mu.Unlock() +} + +func TestMultipleHandlers(t *testing.T) { + cfg := Config{ + BasePort: 30100, + NumPorts: 1, + } + + tr := NewTransport(cfg) + h1, h2 := &testHandler{}, &testHandler{} + tr.AddHandler(h1) + tr.AddHandler(h2) + + if err := tr.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + defer tr.Stop() + + time.Sleep(100 * time.Millisecond) + + testData := []byte("test multiple handlers") + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: cfg.BasePort, + } + + if err := tr.SendShred(testData, addr); err != nil { + t.Fatalf("SendShred failed: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + for i, h := range []*testHandler{h1, h2} { + h.mu.Lock() + if len(h.shreds) != 1 { + t.Errorf("Handler %d: Expected 1 shred, got %d", i, len(h.shreds)) + } + if string(h.shreds[0]) != 
string(testData) { + t.Errorf("Handler %d: Expected %q, got %q", i, testData, h.shreds[0]) + } + h.mu.Unlock() + } +} diff --git a/gturbine/gtshred/benchmark_test.go b/gturbine/gtshred/benchmark_test.go new file mode 100644 index 0000000..4c98e6e --- /dev/null +++ b/gturbine/gtshred/benchmark_test.go @@ -0,0 +1,133 @@ +package gtshred + +import ( + "context" + "crypto/rand" + "testing" + "time" +) + +type noopCallback struct{} + +func (n *noopCallback) ProcessBlock(height uint64, blockHash []byte, block []byte) error { + return nil +} + +func BenchmarkShredProcessing(b *testing.B) { + sizes := []struct { + name string + size int + chunkSize uint32 + }{ + {"8MB", 8 << 20, 1 << 18}, // 256KB chunks + {"16MB", 16 << 20, 1 << 19}, // 512KB chunks + {"32MB", 32 << 20, 1 << 20}, // 1MB chunks + {"64MB", 64 << 20, 1 << 21}, // 2MB chunks + {"128MB", 128 << 20, 1 << 22}, // 4MB chunks + } + + for _, size := range sizes { + b.Run(size.name, func(b *testing.B) { + // Generate random block data + block := make([]byte, size.size) + _, err := rand.Read(block) + if err != nil { + b.Fatal(err) + } + + // Create processor with noop callback + p := NewProcessor(&noopCallback{}, time.Minute) + go p.RunBackgroundCleanup(context.Background()) + + // Reset timer before main benchmark loop + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Create shred group with appropriate chunk size + group, err := NewShredGroup(block, uint64(i), 32, 32, size.chunkSize) + if err != nil { + b.Fatal(err) + } + + // Process all data shreds + for _, shred := range group.DataShreds { + if err := p.CollectShred(shred); err != nil { + b.Fatal(err) + } + } + + b.StopTimer() + // Reset processor state between iterations + p.groups = make(map[string]*ShredGroupWithTimestamp) + p.completedBlocks = make(map[string]time.Time) + b.StartTimer() + } + }) + } +} + +func BenchmarkShredReconstruction(b *testing.B) { + // Test reconstruction with different loss patterns + patterns := []struct { + name string + lossRate float64 + }{ + {"10% Loss", 0.10}, + {"25% Loss", 0.25}, + {"40% Loss", 0.40}, + } + + // Use 32MB block with 1MB chunks + block := make([]byte, 32<<20) + chunkSize := uint32(1 << 20) // 1MB chunks + + _, err := rand.Read(block) + if err != nil { + b.Fatal(err) + } + + for _, pattern := range patterns { + b.Run(pattern.name, func(b *testing.B) { + // Create processor + p := NewProcessor(&noopCallback{}, time.Minute) + go p.RunBackgroundCleanup(context.Background()) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Create shred group with appropriate chunk size + group, err := NewShredGroup(block, uint64(i), 32, 32, chunkSize) + if err != nil { + b.Fatal(err) + } + + // Simulate packet loss + lossCount := int(float64(len(group.DataShreds)) * pattern.lossRate) + for j := 0; j < lossCount; j++ { + group.DataShreds[j] = nil + } + + // Process remaining shreds + for _, shred := range group.DataShreds { + if shred != nil { + if err := p.CollectShred(shred); err != nil { + b.Fatal(err) + } + } + } + + // Process recovery shreds + for _, shred := range group.RecoveryShreds { + if err := p.CollectShred(shred); err != nil { + b.Fatal(err) + } + } + + b.StopTimer() + p.groups = make(map[string]*ShredGroupWithTimestamp) + p.completedBlocks = make(map[string]time.Time) + b.StartTimer() + } + }) + } +} diff --git a/gturbine/gtshred/process_shred_test.go b/gturbine/gtshred/process_shred_test.go new file mode 100644 index 0000000..393c6ba --- /dev/null +++ b/gturbine/gtshred/process_shred_test.go @@ -0,0 +1,298 @@ +package gtshred + 
+import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "testing" + "time" +) + +const ( + DefaultChunkSize = 64 * 1024 // 64KB + DefaultDataShreds = 16 // Number of data shreds + DefaultRecoveryShreds = 4 // Number of recovery shreds + TestHeight = uint64(1000) // Test block height +) + +func makeRandomBlock(size int) []byte { + block := make([]byte, size) + if _, err := rand.Read(block); err != nil { + panic(err) + } + return block +} + +func corrupt(data []byte) { + if len(data) > 0 { + // Flip some bits in the middle of the data + mid := len(data) / 2 + data[mid] ^= 0xFF + if len(data) > mid+1 { + data[mid+1] ^= 0xFF + } + } +} + +type testCase struct { + name string + blockSize int + corrupt []int // indices of shreds to corrupt and then mark as missing + remove []int // indices of shreds to remove + expectErr bool +} + +type testProcessorCallback struct { + count int + blockHash []byte + data []byte +} + +func (cb *testProcessorCallback) ProcessBlock(height uint64, blockHash []byte, block []byte) error { + cb.count++ + cb.data = block + cb.blockHash = blockHash + return nil +} + +func TestProcessorShredding(t *testing.T) { + tests := []testCase{ + { + name: "even block size", + blockSize: DefaultChunkSize * DefaultDataShreds, + }, + { + name: "uneven block size", + blockSize: DefaultChunkSize*DefaultDataShreds - 1000, + }, + { + name: "oversized block", + blockSize: DefaultChunkSize*DefaultDataShreds + 1, + expectErr: true, + }, + { + name: "minimum block size", + blockSize: 1, + }, + { + name: "empty block", + blockSize: 0, + expectErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var cb = new(testProcessorCallback) + + p := NewProcessor(cb, time.Minute) + go p.RunBackgroundCleanup(context.Background()) + + block := makeRandomBlock(tc.blockSize) + group, err := NewShredGroup(block, TestHeight, DefaultDataShreds, DefaultRecoveryShreds, DefaultChunkSize) + + if tc.expectErr { + if err == nil { + t.Error("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify all shreds are properly sized + for i := range group.DataShreds { + if len(group.DataShreds[i].Data) != int(DefaultChunkSize) { + t.Errorf("data shred %d wrong size: got %d want %d", + i, len(group.DataShreds[i].Data), DefaultChunkSize) + } + } + + // Collect threshold shreds into processor + + // collect all data shreds except the last 4, so that recovery shreds are necessary to reassemble + for i := 0; i < DefaultDataShreds-4; i++ { + p.CollectShred(group.DataShreds[i]) + } + + // collect all recovery shreds + for i := 0; i < DefaultRecoveryShreds; i++ { + p.CollectShred(group.RecoveryShreds[i]) + } + + if p.cb.(*testProcessorCallback).count != 1 { + t.Error("expected ProcessBlock to be called once") + } + + blockHash := sha256.Sum256(block) + + if !bytes.Equal(blockHash[:], cb.blockHash) { + t.Errorf("block hash mismatch: got %v want %v", cb.blockHash, group.BlockHash) + } + + if !bytes.Equal(block, cb.data) { + t.Errorf("reassembled block doesn't match original: got len %d want len %d", + len(cb.data), len(block)) + } + + }) + } +} + +// func TestProcessorRecovery(t *testing.T) { +// tests := []testCase{ +// { +// name: "recover with missing data shreds", +// blockSize: DefaultChunkSize * (DefaultDataShreds - 1), +// remove: []int{0, 1}, // Remove first two data shreds +// }, +// { +// name: "recover with corrupted data shreds", +// blockSize: DefaultChunkSize * DefaultDataShreds, +// corrupt: []int{0, 1}, 
// Corrupt first two data shreds +// }, +// { +// name: "too many missing shreds", +// blockSize: DefaultChunkSize * DefaultDataShreds, +// remove: []int{0, 1, 2, 3, 4, 5}, // Remove more than recoverable +// expectErr: true, +// }, +// { +// name: "mixed corruption and missing", +// blockSize: DefaultChunkSize * DefaultDataShreds, +// corrupt: []int{0}, +// remove: []int{1}, +// }, +// { +// name: "boundary size block with last shred corrupted", +// blockSize: DefaultChunkSize*DefaultDataShreds - 1, +// corrupt: []int{DefaultDataShreds - 1}, // Corrupt last shred +// }, +// } + +// var cb = new(testProcessorCallback) + +// for _, tc := range tests { +// t.Run(tc.name, func(t *testing.T) { +// p, err := NewProcessor(DefaultChunkSize, DefaultDataShreds, DefaultRecoveryShreds) +// if err != nil { +// t.Fatal(err) +// } + +// block := makeRandomBlock(tc.blockSize) +// group, err := p.ProcessBlock(block, TestHeight) +// if err != nil { +// t.Fatal(err) +// } + +// // Apply corruptions - corrupted shreds are immediately marked as nil +// for _, idx := range tc.corrupt { +// if idx < len(group.DataShreds) && group.DataShreds[idx] != nil { +// // First corrupt the data +// corrupt(group.DataShreds[idx].Data) +// // Then mark it as missing since it's corrupted +// group.DataShreds[idx] = nil +// } +// } + +// // Remove shreds +// for _, idx := range tc.remove { +// if idx < len(group.DataShreds) { +// group.DataShreds[idx] = nil +// } +// } + +// // Try reassembly +// reassembled, err := p.ReassembleBlock(group) + +// if tc.expectErr { +// if err == nil { +// t.Error("expected error but got none") +// } +// return +// } + +// if err != nil { +// t.Fatalf("unexpected error: %v", err) +// } + +// if !bytes.Equal(block, reassembled) { +// t.Errorf("reassembled block doesn't match original: got len %d want len %d", +// len(reassembled), len(block)) +// } +// }) +// } +// } + +// func TestProcessorEdgeCases(t *testing.T) { +// t.Run("nil group", func(t *testing.T) { +// p, _ := NewProcessor(DefaultChunkSize, DefaultDataShreds, DefaultRecoveryShreds) +// _, err := p.ReassembleBlock(nil) +// if err == nil { +// t.Error("expected error for nil group") +// } +// }) + +// t.Run("mismatched heights", func(t *testing.T) { +// p, _ := NewProcessor(DefaultChunkSize, DefaultDataShreds, DefaultRecoveryShreds) +// block := makeRandomBlock(DefaultChunkSize) +// group, _ := p.ProcessBlock(block, TestHeight) + +// // Modify a shred height +// group.DataShreds[0].Height = TestHeight + 1 + +// _, err := p.ReassembleBlock(group) +// if err == nil { +// t.Error("expected error for mismatched heights") +// } +// }) + +// t.Run("invalid chunk size", func(t *testing.T) { +// _, err := NewProcessor(0, DefaultDataShreds, DefaultRecoveryShreds) +// if err == nil { +// t.Error("expected error for chunk size 0") +// } + +// _, err = NewProcessor(maxChunkSize+1, DefaultDataShreds, DefaultRecoveryShreds) +// if err == nil { +// t.Error("expected error for chunk size > max") +// } +// }) +// } + +// func BenchmarkProcessor(b *testing.B) { +// sizes := []int{ +// 1024, // 1KB +// 1024 * 1024, // 1MB +// 10 * 1024 * 1024, // 10MB +// } + +// for _, size := range sizes { +// b.Run(b.Name(), func(b *testing.B) { +// p, err := NewProcessor(DefaultChunkSize, DefaultDataShreds, DefaultRecoveryShreds) +// if err != nil { +// b.Fatal(err) +// } + +// block := makeRandomBlock(size) +// b.ResetTimer() + +// for i := 0; i < b.N; i++ { +// group, err := p.ProcessBlock(block, TestHeight) +// if err != nil { +// b.Fatal(err) +// } + +// _, err = 
p.ReassembleBlock(group) +// if err != nil { +// b.Fatal(err) +// } +// } + +// b.SetBytes(int64(size)) +// }) +// } +// } diff --git a/gturbine/gtshred/processor.go b/gturbine/gtshred/processor.go new file mode 100644 index 0000000..1c62858 --- /dev/null +++ b/gturbine/gtshred/processor.go @@ -0,0 +1,238 @@ +package gtshred + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gordian-engine/gordian/gturbine" + "github.com/gordian-engine/gordian/gturbine/gtencoding" +) + +// Constants for error checking +const ( + minChunkSize = 1024 // 1KB minimum + maxChunkSize = 1 << 20 // 1MB maximum chunk size + maxBlockSize = 128 * 1024 * 1024 // 128MB maximum block size (matches Solana) +) + +// ShredGroupWithTimestamp is a ShredGroup with a timestamp for tracking when the group was created (when the first shred was received). +type ShredGroupWithTimestamp struct { + *ShredGroup + Timestamp time.Time +} + +type Processor struct { + // cb is the callback to call when a block is fully reassembled + cb ProcessorCallback + + // groups is a cache of shred groups currently being processed. + groups map[string]*ShredGroupWithTimestamp + groupsMu sync.RWMutex + + // completedBlocks is a cache of block hashes that have been fully reassembled and should no longer be processed. + completedBlocks map[string]time.Time + completedBlocksMu sync.RWMutex + + // cleanupInterval is the interval at which stale groups are cleaned up and completed blocks are removed + cleanupInterval time.Duration +} + +// ProcessorCallback is the interface for processor callbacks. +type ProcessorCallback interface { + ProcessBlock(height uint64, blockHash []byte, block []byte) error +} + +// NewProcessor creates a new Processor with the given callback and cleanup interval. +func NewProcessor(cb ProcessorCallback, cleanupInterval time.Duration) *Processor { + return &Processor{ + cb: cb, + groups: make(map[string]*ShredGroupWithTimestamp), + completedBlocks: make(map[string]time.Time), + cleanupInterval: cleanupInterval, + } +} + +// RunBackgroundCleanup starts a cleanup loop that runs at the cleanup interval. +// This should be run as a goroutine. +func (p *Processor) RunBackgroundCleanup(ctx context.Context) { + ticker := time.NewTicker(p.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case now := <-ticker.C: + p.cleanupStaleGroups(now) + } + } +} + +// CollectShred processes an incoming data shred. +func (p *Processor) CollectShred(shred *gturbine.Shred) error { + if shred == nil { + return fmt.Errorf("nil shred") + } + + // Skip shreds from already processed blocks + if p.isCompleted(shred.BlockHash) { + return nil + } + + group, ok := p.getGroup(shred.GroupID) + if !ok { + // If the group doesn't exist, create it and add the shred + return p.initGroup(shred) + } + + group.mu.Lock() + defer group.mu.Unlock() + + // After locking the group, check if the block has already been completed. 
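+	// Another goroutine may have reconstructed and completed this block between the
+	// earlier isCompleted check and acquiring group.mu, so re-check here before doing
+	// any reconstruction work.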
+ if p.isCompleted(group.BlockHash) { + return nil + } + + full, err := group.collectShred(shred) + if err != nil { + return fmt.Errorf("failed to collect data shred: %w", err) + } + if full { + encoder, err := gtencoding.NewEncoder(group.TotalDataShreds, group.TotalRecoveryShreds) + if err != nil { + return fmt.Errorf("failed to create encoder: %w", err) + } + + block, err := group.reconstructBlock(encoder) + if err != nil { + return fmt.Errorf("failed to reconstruct block: %w", err) + } + + if err := p.cb.ProcessBlock(shred.Height, shred.BlockHash, block); err != nil { + return fmt.Errorf("failed to process block: %w", err) + } + + p.deleteGroup(shred.GroupID) + // then mark the block as completed at time.Now() + p.setCompleted(shred.BlockHash) + } + return nil +} + +// cleanupStaleGroups removes groups that have been inactive for longer than the cleanup interval. +func (p *Processor) cleanupStaleGroups(now time.Time) { + var deleteHashes []string + + p.completedBlocksMu.RLock() + for hash, completedAt := range p.completedBlocks { + if now.Sub(completedAt) > p.cleanupInterval { + deleteHashes = append(deleteHashes, hash) + } + } + p.completedBlocksMu.RUnlock() + + if len(deleteHashes) != 0 { + // Take write lock once for all deletions + p.completedBlocksMu.Lock() + for _, hash := range deleteHashes { + delete(p.completedBlocks, hash) + } + p.completedBlocksMu.Unlock() + } + + var deleteGroups []string + + // Take read lock on groups to check for groups to delete (stale or duplicate blockhash) + p.groupsMu.RLock() + for id, group := range p.groups { + for _, hash := range deleteHashes { + // Check if group is associated with a completed block + if string(group.BlockHash) == hash { + deleteGroups = append(deleteGroups, id) + } + } + + // Check if group is stale + if now.Sub(group.Timestamp) > p.cleanupInterval { + deleteGroups = append(deleteGroups, id) + } + } + p.groupsMu.RUnlock() + + if len(deleteGroups) != 0 { + // Take write lock once for all deletions + p.groupsMu.Lock() + for _, id := range deleteGroups { + delete(p.groups, id) + } + p.groupsMu.Unlock() + } +} + +// initGroup creates a new group and adds the first shred to it. +func (p *Processor) initGroup(shred *gturbine.Shred) error { + now := time.Now() + group := &ShredGroup{ + DataShreds: make([]*gturbine.Shred, shred.TotalDataShreds), + RecoveryShreds: make([]*gturbine.Shred, shred.TotalRecoveryShreds), + TotalDataShreds: shred.TotalDataShreds, + TotalRecoveryShreds: shred.TotalRecoveryShreds, + GroupID: shred.GroupID, + BlockHash: shred.BlockHash, + Height: shred.Height, + OriginalSize: shred.FullDataSize, + } + + group.DataShreds[shred.Index] = shred + + p.groupsMu.Lock() + + if _, ok := p.groups[shred.GroupID]; ok { + // If a group already exists, return early to avoid overwriting + p.groupsMu.Unlock() + + // Collect the shred into the existing group + return p.CollectShred(shred) + } + + defer p.groupsMu.Unlock() + + p.groups[shred.GroupID] = &ShredGroupWithTimestamp{ + ShredGroup: group, + Timestamp: now, + } + + return nil +} + +// getGroup returns the group with the given ID, if it exists. +func (p *Processor) getGroup(groupID string) (*ShredGroupWithTimestamp, bool) { + p.groupsMu.RLock() + defer p.groupsMu.RUnlock() + group, ok := p.groups[groupID] + return group, ok +} + +// deleteGroup removes the group with the given ID from the processor. +func (p *Processor) deleteGroup(groupID string) { + p.groupsMu.Lock() + defer p.groupsMu.Unlock() + delete(p.groups, groupID) +} + +// setCompleted marks a block as completed. 
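+// Entries remain in completedBlocks until cleanupStaleGroups evicts them after
+// cleanupInterval, so late or duplicate shreds for an already reconstructed block
+// are dropped instead of re-triggering reconstruction.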
+func (p *Processor) setCompleted(blockHash []byte) { + p.completedBlocksMu.Lock() + defer p.completedBlocksMu.Unlock() + p.completedBlocks[string(blockHash)] = time.Now() +} + +// isCompleted checks if a block has been marked as completed. +func (p *Processor) isCompleted(blockHash []byte) bool { + p.completedBlocksMu.RLock() + defer p.completedBlocksMu.RUnlock() + _, ok := p.completedBlocks[string(blockHash)] + return ok +} diff --git a/gturbine/gtshred/processor_test.go b/gturbine/gtshred/processor_test.go new file mode 100644 index 0000000..15e3293 --- /dev/null +++ b/gturbine/gtshred/processor_test.go @@ -0,0 +1,224 @@ +package gtshred + +import ( + "context" + "testing" + "time" +) + +func TestProcessorMemoryCleanup(t *testing.T) { + // Create processor with short cleanup interval for testing + var cb = new(testProcessorCallback) + cleanupInterval := 100 * time.Millisecond + p := NewProcessor(cb, cleanupInterval) + go p.RunBackgroundCleanup(context.Background()) + + // Create a test block and shred group + block := []byte("test block data") + group, err := NewShredGroup(block, 1, 2, 1, 100) + if err != nil { + t.Fatal(err) + } + + // Process some shreds from the group to mark it complete + for i := 0; i < len(group.DataShreds); i++ { + err := p.CollectShred(group.DataShreds[i]) + if err != nil { + t.Fatal(err) + } + } + + // Verify block is marked as completed + if _, exists := p.completedBlocks[string(group.BlockHash)]; !exists { + t.Error("block should be marked as completed") + } + + // Try to process another shred from same block + err = p.CollectShred(group.RecoveryShreds[0]) + if err != nil { + t.Fatal(err) + } + + // Verify no new group was created for this block + groupCount := len(p.groups) + if groupCount > 1 { + t.Errorf("expected at most 1 group, got %d", groupCount) + } + + // Wait for cleanup + time.Sleep(cleanupInterval * 4) + + // Verify completed block was cleaned up + p.completedBlocksMu.RLock() + defer p.completedBlocksMu.RUnlock() + if _, exists := p.completedBlocks[string(group.BlockHash)]; exists { + t.Error("completed block should have been cleaned up") + } +} + +// func TestProcessor(t *testing.T) { +// t.Run("basic shred and reassemble", func(t *testing.T) { +// // Use 32:32 config +// var cb = new(testProcessorCallback) +// processor, err := NewProcessor(cb)) +// if err != nil { +// t.Fatal(err) +// } + +// // Calculate a valid block size based on configuration +// blockSize := int(DefaultChunkSize) * 32 // 32 data shreds +// block := make([]byte, blockSize) +// if _, err := rand.Read(block); err != nil { +// t.Fatal(err) +// } + +// group, err := processor.ProcessBlock(block, 1) +// if err != nil { +// t.Fatal(err) +// } + +// if group == nil { +// t.Fatal("expected non-nil group") +// } + +// if len(group.DataShreds) != 32 { +// t.Errorf("expected %d data shreds, got %d", 32, len(group.DataShreds)) +// } + +// reassembled, err := processor.ReassembleBlock(group) +// if err != nil { +// t.Fatal(err) +// } + +// if !bytes.Equal(block, reassembled) { +// t.Error("reassembled block does not match original") +// } +// }) + +// t.Run("block size constraints", func(t *testing.T) { +// processor, _ := NewProcessor(DefaultChunkSize, 32, 32) + +// // Test block exactly at configured max size +// maxSize := int(DefaultChunkSize) * 32 +// block := make([]byte, maxSize) +// if _, err := rand.Read(block); err != nil { +// t.Fatal(err) +// } + +// if _, err := processor.ProcessBlock(block, 1); err != nil { +// t.Errorf("failed to process max size block: %v", err) +// } + 
+// // Test oversized block +// block = make([]byte, maxSize+1) +// if _, err := rand.Read(block); err != nil { +// t.Fatal(err) +// } + +// if _, err := processor.ProcessBlock(block, 1); err == nil { +// t.Error("expected error for oversized block") +// } +// }) + +// t.Run("packet loss scenarios", func(t *testing.T) { +// scenarios := []struct { +// name string +// lossRate float64 +// shouldRecover bool +// }{ +// {"15% loss", 0.15, true}, +// {"45% loss", 0.45, true}, // Now recoverable with 32:32 configuration +// {"60% loss", 0.60, false}, // Still too many losses to recover from +// } + +// for _, sc := range scenarios { +// t.Run(sc.name, func(t *testing.T) { +// processor, _ := NewProcessor(DefaultChunkSize, 32, 32) + +// // Use max configured block size +// blockSize := int(DefaultChunkSize) * 32 +// block := make([]byte, blockSize) +// if _, err := rand.Read(block); err != nil { +// t.Fatal(err) +// } + +// group, err := processor.ProcessBlock(block, 1) +// if err != nil { +// t.Fatal(err) +// } + +// // Calculate shreds to drop +// totalShreds := len(group.DataShreds) + len(group.RecoveryShreds) +// dropCount := int(math.Round(float64(totalShreds) * sc.lossRate)) + +// // Track which shreds we've dropped +// dropped := make(map[int]bool) +// for i := 0; i < dropCount; i++ { +// var idx int +// // Keep trying until we find an undropped shred +// for { +// idx = mrand.Intn(totalShreds) +// if !dropped[idx] { +// dropped[idx] = true +// break +// } +// } + +// // Drop the shred +// if idx < len(group.DataShreds) { +// group.DataShreds[idx] = nil +// } else { +// group.RecoveryShreds[idx-len(group.DataShreds)] = nil +// } +// } + +// // Attempt reassembly +// reassembled, err := processor.ReassembleBlock(group) +// if sc.shouldRecover { +// if err != nil { +// t.Errorf("expected recovery to succeed, got: %v", err) +// } else if !bytes.Equal(block, reassembled) { +// t.Error("reassembled block does not match original") +// } +// } else if err == nil { +// t.Error("expected recovery to fail") +// } +// }) +// } +// }) + +// t.Run("varying block sizes", func(t *testing.T) { +// processor, _ := NewProcessor(DefaultChunkSize, 32, 32) +// maxSize := int(DefaultChunkSize) * 32 + +// testSizes := []int{ +// DefaultChunkSize, // One chunk +// maxSize/2, // Half max size +// maxSize - DefaultChunkSize, // One chunk less than max +// maxSize, // Exact max size +// } + +// for _, size := range testSizes { +// block := make([]byte, size) +// if _, err := rand.Read(block); err != nil { +// t.Fatal(err) +// } + +// group, err := processor.ProcessBlock(block, 1) +// if err != nil { +// t.Errorf("failed to process block of size %d: %v", size, err) +// continue +// } + +// reassembled, err := processor.ReassembleBlock(group) +// if err != nil { +// t.Errorf("failed to reassemble block of size %d: %v", size, err) +// continue +// } + +// if !bytes.Equal(block, reassembled) { +// t.Errorf("mismatch for block size %d", size) +// } +// } +// }) +// } diff --git a/gturbine/gtshred/shred_group.go b/gturbine/gtshred/shred_group.go new file mode 100644 index 0000000..9b6bc8d --- /dev/null +++ b/gturbine/gtshred/shred_group.go @@ -0,0 +1,229 @@ +package gtshred + +import ( + "crypto/sha256" + "fmt" + "sync" + + "github.com/google/uuid" + "github.com/gordian-engine/gordian/gturbine" + "github.com/gordian-engine/gordian/gturbine/gtencoding" +) + +// ShredGroup represents a group of shreds that can be used to reconstruct a block. 
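+//
+// Illustrative proposer-side flow, mirroring turbine_test.go (parameter values are examples only):
+//
+//	group, err := NewShredGroup(block, height, 16, 4, 64*1024)
+//	if err != nil {
+//		// handle error
+//	}
+//	for _, s := range append(group.DataShreds, group.RecoveryShreds...) {
+//		// encode s with a gtencoding codec and send it via gtnetwork
+//	}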
+type ShredGroup struct { + DataShreds []*gturbine.Shred + RecoveryShreds []*gturbine.Shred + TotalDataShreds int + TotalRecoveryShreds int + GroupID string // Changed to string for UUID + BlockHash []byte + Height uint64 // Added to struct level + OriginalSize int + + mu sync.Mutex +} + +// NewShredGroup creates a new ShredGroup from a block of data +func NewShredGroup(block []byte, height uint64, dataShreds, recoveryShreds int, chunkSize uint32) (*ShredGroup, error) { + if len(block) == 0 { + return nil, fmt.Errorf("empty block") + } + if len(block) > maxBlockSize { + return nil, fmt.Errorf("block too large: %d bytes exceeds max size %d", len(block), maxBlockSize) + } + if len(block) > int(chunkSize)*dataShreds { + return nil, fmt.Errorf("block too large for configured shred size: %d bytes exceeds max size %d", len(block), chunkSize*uint32(dataShreds)) + } + + // Create encoder for this block + encoder, err := gtencoding.NewEncoder(dataShreds, recoveryShreds) + if err != nil { + return nil, fmt.Errorf("failed to create encoder: %w", err) + } + + // Calculate block hash for verification + // TODO hasher should be interface. + blockHash := sha256.Sum256(block) + + // Create new shred group + group := &ShredGroup{ + DataShreds: make([]*gturbine.Shred, dataShreds), + RecoveryShreds: make([]*gturbine.Shred, recoveryShreds), + TotalDataShreds: dataShreds, + TotalRecoveryShreds: recoveryShreds, + GroupID: uuid.New().String(), + BlockHash: blockHash[:], + Height: height, + OriginalSize: len(block), + } + + // Create fixed-size data chunks + dataBytes := make([][]byte, dataShreds) + bytesPerShred := int(chunkSize) + + // Initialize all shreds to full chunk size with zeros + for i := 0; i < dataShreds; i++ { + dataBytes[i] = make([]byte, bytesPerShred) + } + + // Copy data into shreds + remaining := len(block) + offset := 0 + for i := 0; i < dataShreds && remaining > 0; i++ { + toCopy := remaining + if toCopy > bytesPerShred { + toCopy = bytesPerShred + } + copy(dataBytes[i], block[offset:offset+toCopy]) + offset += toCopy + remaining -= toCopy + } + + // Generate recovery data using erasure coding + recoveryBytes, err := encoder.GenerateRecoveryShreds(dataBytes) + if err != nil { + return nil, fmt.Errorf("failed to generate recovery shreds: %w", err) + } + + // Create data shreds + for i := range dataBytes { + group.DataShreds[i] = >urbine.Shred{ + Type: gturbine.DataShred, + Index: i, + TotalDataShreds: dataShreds, + TotalRecoveryShreds: recoveryShreds, + Data: dataBytes[i], + BlockHash: blockHash[:], + GroupID: group.GroupID, + Height: height, + FullDataSize: group.OriginalSize, + } + } + + // Create recovery shreds + for i := range recoveryBytes { + group.RecoveryShreds[i] = >urbine.Shred{ + Type: gturbine.RecoveryShred, + Index: i, + TotalDataShreds: dataShreds, + TotalRecoveryShreds: recoveryShreds, + Data: recoveryBytes[i], + BlockHash: blockHash[:], + GroupID: group.GroupID, + Height: height, + FullDataSize: group.OriginalSize, + } + } + + return group, nil +} + +// isFull checks if enough shreds are available to attempt reconstruction. 
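+// Any TotalDataShreds of the TotalDataShreds+TotalRecoveryShreds shards are enough,
+// since Reed-Solomon can regenerate the missing data shards from that many pieces,
+// so data and recovery shreds are counted together.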
+func (g *ShredGroup) isFull() bool { + valid := 0 + for _, s := range g.DataShreds { + if s != nil { + valid++ + } + } + + for _, s := range g.RecoveryShreds { + if s != nil { + valid++ + } + } + + return valid >= g.TotalDataShreds +} + +// reconstructBlock attempts to reconstruct the original block from available shreds +func (g *ShredGroup) reconstructBlock(encoder *gtencoding.Encoder) ([]byte, error) { + // Extract data bytes for erasure coding + allBytes := make([][]byte, len(g.DataShreds)+len(g.RecoveryShreds)) + + // Copy available data shreds + for i, shred := range g.DataShreds { + if shred != nil { + allBytes[i] = make([]byte, len(shred.Data)) + copy(allBytes[i], shred.Data) + } + } + + // Copy available recovery shreds + for i, shred := range g.RecoveryShreds { + if shred != nil { + allBytes[i+len(g.DataShreds)] = make([]byte, len(shred.Data)) + copy(allBytes[i+len(g.DataShreds)], shred.Data) + } + } + + // Reconstruct missing data + if err := encoder.Reconstruct(allBytes); err != nil { + return nil, fmt.Errorf("failed to reconstruct data: %w", err) + } + + // Combine data shreds + reconstructed := make([]byte, 0, g.OriginalSize) + remaining := g.OriginalSize + + for i := 0; i < len(g.DataShreds) && remaining > 0; i++ { + if allBytes[i] == nil { + return nil, fmt.Errorf("reconstruction failed: missing data for shard %d", i) + } + toCopy := remaining + if toCopy > len(allBytes[i]) { + toCopy = len(allBytes[i]) + } + reconstructed = append(reconstructed, allBytes[i][:toCopy]...) + remaining -= toCopy + } + + // Verify reconstructed block hash + // TODO hasher should be interface. + computedHash := sha256.Sum256(reconstructed) + if string(computedHash[:]) != string(g.BlockHash) { + return nil, fmt.Errorf("block hash mismatch after reconstruction") + } + + return reconstructed, nil +} + +// collectShred adds a data shred to the group +func (g *ShredGroup) collectShred(shred *gturbine.Shred) (bool, error) { + if shred == nil { + return false, fmt.Errorf("nil shred") + } + + // Validate shred matches group parameters + if shred.GroupID != g.GroupID { + return false, fmt.Errorf("group ID mismatch: got %s, want %s", shred.GroupID, g.GroupID) + } + if shred.Height != g.Height { + return false, fmt.Errorf("height mismatch: got %d, want %d", shred.Height, g.Height) + } + if string(shred.BlockHash) != string(g.BlockHash) { + return false, fmt.Errorf("block hash mismatch") + } + + switch shred.Type { + case gturbine.DataShred: + // Validate shred index + if int(shred.Index) >= len(g.DataShreds) { + return false, fmt.Errorf("invalid data shred index: %d", shred.Index) + } + + g.DataShreds[shred.Index] = shred + case gturbine.RecoveryShred: + // Validate shred index + if int(shred.Index) >= len(g.RecoveryShreds) { + return false, fmt.Errorf("invalid recovery shred index: %d", shred.Index) + } + + g.RecoveryShreds[shred.Index] = shred + default: + return false, fmt.Errorf("invalid shred type: %d", shred.Type) + } + + return g.isFull(), nil +} diff --git a/gturbine/turbine.go b/gturbine/turbine.go new file mode 100644 index 0000000..a534c72 --- /dev/null +++ b/gturbine/turbine.go @@ -0,0 +1,53 @@ +package gturbine + +type Tree struct { + Root *Layer + Height uint32 + Fanout uint32 +} + +type Layer struct { + Validators []uint64 + Parent *Layer + Children []*Layer +} + +type ShredType int + +const ( + DataShred ShredType = iota + RecoveryShred +) + +// Shred represents a piece of a block that can be sent over the network +type Shred struct { + // Metadata for block reconstruction + FullDataSize 
int // Size of the full block + BlockHash []byte // Hash for data verification + GroupID string // UUID for associating shreds from the same block + Height uint64 // Block height for chain reference + + // Shred-specific metadata + Type ShredType + Index int // Index of this shred within the block + TotalDataShreds int // Total number of shreds for this block + TotalRecoveryShreds int // Total number of shreds for this block + + Data []byte // The actual shred data +} + +// GetLayerByHeight returns layer at given height (0-based) +func (t *Tree) GetLayerByHeight(height uint32) *Layer { + if height >= t.Height { + return nil + } + + current := t.Root + for i := uint32(0); i < height; i++ { + if len(current.Children) == 0 { + return nil + } + current = current.Children[0] + } + return current +} diff --git a/gturbine/turbine_test.go b/gturbine/turbine_test.go new file mode 100644 index 0000000..07c0758 --- /dev/null +++ b/gturbine/turbine_test.go @@ -0,0 +1,204 @@ +package gturbine_test + +import ( + "bytes" + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/gordian-engine/gordian/gturbine/gtencoding" + "github.com/gordian-engine/gordian/gturbine/gtnetwork" + "github.com/gordian-engine/gordian/gturbine/gtshred" +) + +type testNode struct { + transport *gtnetwork.Transport + processor *gtshred.Processor + codec gtencoding.ShardCodec + shredHandler *testShredHandler + blockHandler *testBlockHandler +} + +type testShredHandler struct { + node *testNode // back-reference for reconstruction +} + +func (h *testShredHandler) HandleShred(data []byte, from net.Addr) error { + shred, err := h.node.codec.Decode(data) + if err != nil { + return fmt.Errorf("failed to decode shred: %w", err) + } + return h.node.processor.CollectShred(shred) +} + +type testBlock struct { + height uint64 + blockHash []byte + block []byte +} + +type testBlockHandler struct { + blocks []*testBlock + mu sync.Mutex +} + +func (h *testBlockHandler) ProcessBlock(height uint64, blockHash []byte, block []byte) error { + h.mu.Lock() + defer h.mu.Unlock() + h.blocks = append(h.blocks, &testBlock{ + height: height, + blockHash: blockHash, + block: block, + }) + return nil +} + +func newTestNode(t *testing.T, basePort int) *testNode { + encoder := gtencoding.NewBinaryShardCodec() + + transport := gtnetwork.NewTransport(gtnetwork.Config{ + BasePort: basePort, + NumPorts: 10, + }) + + cb := &testBlockHandler{} + + processor := gtshred.NewProcessor(cb, time.Minute) + go processor.RunBackgroundCleanup(context.Background()) + + shredHandler := &testShredHandler{} + node := &testNode{ + transport: transport, + processor: processor, + codec: encoder, + shredHandler: shredHandler, + blockHandler: cb, + } + shredHandler.node = node + + transport.AddHandler(shredHandler) + if err := transport.Start(); err != nil { + t.Fatalf("Failed to start transport: %v", err) + } + + return node +} + +func (n *testNode) stop() { + n.transport.Stop() +} + +func TestBlockPropagation(t *testing.T) { + // Create two test nodes + node1 := newTestNode(t, 40000) + defer node1.stop() + + node2 := newTestNode(t, 40010) + defer node2.stop() + + // Allow transports to start + time.Sleep(100 * time.Millisecond) + + // Create and process a test block + originalBlock := []byte("test block data for propagation") + + const testHeight = 12345 + + // Node 1: Shred the block + shredGroup, err := gtshred.NewShredGroup(originalBlock, testHeight, 16, 4, 1024) + if err != nil { + t.Fatalf("Failed to shred block: %v", err) + } + + // Node 1: Encode and send 
shreds to Node 2 + for i, shred := range append(shredGroup.DataShreds, shredGroup.RecoveryShreds...) { + encodedShred, err := node1.codec.Encode(shred) + if err != nil { + t.Fatalf("Failed to encode shred: %v", err) + } + err = node1.transport.SendShred(encodedShred, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: node2.transport.BasePort() + i%10, + }) + if err != nil { + t.Fatalf("Failed to send shred: %v", err) + } + } + + // Wait for processing + time.Sleep(300 * time.Millisecond) + + node2.blockHandler.mu.Lock() + defer node2.blockHandler.mu.Unlock() + + // Verify Node 2 received and reconstructed the block + if len(node2.blockHandler.blocks) != 1 { + t.Fatalf("Expected 1 reconstructed block, got %d", len(node2.blockHandler.blocks)) + } + + if !bytes.Equal(node2.blockHandler.blocks[0].block, originalBlock) { + t.Errorf("Block mismatch: got %q, want %q", node2.blockHandler.blocks[0], originalBlock) + } + + if node2.blockHandler.blocks[0].height != testHeight { + t.Fatalf("Block height mismatch: got %d, want %d", node2.blockHandler.blocks[0].height, testHeight) + } +} + +func TestPartialBlockReconstruction(t *testing.T) { + node1 := newTestNode(t, 40020) + defer node1.stop() + + node2 := newTestNode(t, 40030) + defer node2.stop() + + time.Sleep(100 * time.Millisecond) + + originalBlock := []byte("test block for partial reconstruction") + + const testHeight = 54321 + + // Create shreds + shredGroup, err := gtshred.NewShredGroup(originalBlock, testHeight, 16, 4, 1024) + if err != nil { + t.Fatalf("Failed to shred block: %v", err) + } + + // Send only minimum required shreds + minShreds := append(shredGroup.DataShreds[:12], shredGroup.RecoveryShreds...) + for i, shred := range minShreds { + encodedShred, err := node1.codec.Encode(shred) + if err != nil { + t.Fatalf("Failed to encode shred: %v", err) + } + + err = node1.transport.SendShred(encodedShred, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: node2.transport.BasePort() + i%10, + }) + if err != nil { + t.Fatalf("Failed to send shred: %v", err) + } + } + + time.Sleep(100 * time.Millisecond) + + node2.blockHandler.mu.Lock() + defer node2.blockHandler.mu.Unlock() + + // Verify Node 2 received and reconstructed the block + if len(node2.blockHandler.blocks) != 1 { + t.Fatalf("Expected 1 reconstructed block, got %d", len(node2.blockHandler.blocks)) + } + + if !bytes.Equal(node2.blockHandler.blocks[0].block, originalBlock) { + t.Errorf("Block mismatch: got %q, want %q", node2.blockHandler.blocks[0], originalBlock) + } + + if node2.blockHandler.blocks[0].height != testHeight { + t.Fatalf("Block height mismatch: got %d, want %d", node2.blockHandler.blocks[0].height, testHeight) + } +} From b3bf2d78d8810bd26221f4e31018415361b2ac1d Mon Sep 17 00:00:00 2001 From: Andrew Gouin Date: Wed, 11 Dec 2024 19:47:19 -0700 Subject: [PATCH 2/4] Integrate with gerasure --- gerasure/gereedsolomon/compliance_test.go | 13 +- gerasure/gereedsolomon/encoder.go | 16 +- gerasure/gereedsolomon/reconstructor.go | 15 +- gturbine/gtencoding/benchmark_test.go | 268 ++---------------- .../gtencoding/binary_codec_bench_test.go | 71 ++--- gturbine/gtencoding/binary_encoder.go | 53 ++-- gturbine/gtencoding/binary_encoder_test.go | 111 +++++--- gturbine/gtencoding/erasure.go | 88 ------ gturbine/gtencoding/erasure_test.go | 189 ------------ gturbine/gtshred/benchmark_test.go | 46 ++- gturbine/gtshred/process_shred_test.go | 33 +-- gturbine/gtshred/processor.go | 117 +++++--- gturbine/gtshred/processor_test.go | 16 +- gturbine/gtshred/shred_block.go | 78 +++++ 
gturbine/gtshred/shred_group.go | 229 --------------- gturbine/turbine.go | 28 +- gturbine/turbine_test.go | 12 +- 17 files changed, 396 insertions(+), 987 deletions(-) delete mode 100644 gturbine/gtencoding/erasure.go delete mode 100644 gturbine/gtencoding/erasure_test.go create mode 100644 gturbine/gtshred/shred_block.go delete mode 100644 gturbine/gtshred/shred_group.go diff --git a/gerasure/gereedsolomon/compliance_test.go b/gerasure/gereedsolomon/compliance_test.go index abf69dc..a3e7801 100644 --- a/gerasure/gereedsolomon/compliance_test.go +++ b/gerasure/gereedsolomon/compliance_test.go @@ -6,37 +6,32 @@ import ( "github.com/gordian-engine/gordian/gerasure" "github.com/gordian-engine/gordian/gerasure/gerasuretest" "github.com/gordian-engine/gordian/gerasure/gereedsolomon" - "github.com/klauspost/reedsolomon" ) func TestReconstructionCompliance(t *testing.T) { gerasuretest.TestFixedRateErasureReconstructionCompliance( t, func(origData []byte, nData, nParity int) (gerasure.Encoder, gerasure.Reconstructor) { - rs, err := reedsolomon.New(nData, nParity) + enc, err := gereedsolomon.NewEncoder(nData, nParity) if err != nil { panic(err) } // We don't know the shard size until we encode. // (Or at least I don't see how to get that from the reedsolomon package.) - allShards, err := rs.Split(origData) + allShards, err := enc.Encode(nil, origData) if err != nil { panic(err) } shardSize := len(allShards[0]) - enc := gereedsolomon.NewEncoder(rs) - // Separate reedsolomon encoder for the reconstructor. - rrs, err := reedsolomon.New(nData, nParity) + rcons, err := gereedsolomon.NewReconstructor(nData, nParity, shardSize) if err != nil { panic(err) } - r := gereedsolomon.NewReconstructor(rrs, shardSize) - - return enc, r + return enc, rcons }, ) } diff --git a/gerasure/gereedsolomon/encoder.go b/gerasure/gereedsolomon/encoder.go index 6af5c09..f8295f8 100644 --- a/gerasure/gereedsolomon/encoder.go +++ b/gerasure/gereedsolomon/encoder.go @@ -15,13 +15,23 @@ type Encoder struct { // NewEncoder returns a new Encoder. // The options within the given reedsolomon.Encoder determine the number of shards. -func NewEncoder(rs reedsolomon.Encoder) Encoder { - return Encoder{rs: rs} +func NewEncoder(dataShreds, parityShreds int) (*Encoder, error) { + if dataShreds <= 0 { + return nil, fmt.Errorf("data shreds must be > 0") + } + if parityShreds <= 0 { + return nil, fmt.Errorf("parity shreds must be > 0") + } + rs, err := reedsolomon.New(dataShreds, parityShreds) + if err != nil { + return nil, fmt.Errorf("failed to create reed-solomon encoder: %w", err) + } + return &Encoder{rs: rs}, nil } // Encode satisfies [gerasure.Encoder]. // Callers should assume that the Encoder takes ownership of the given data slice. -func (e Encoder) Encode(_ context.Context, data []byte) ([][]byte, error) { +func (e *Encoder) Encode(_ context.Context, data []byte) ([][]byte, error) { // From the original data, produce new subslices for the data shards and parity shards. allShards, err := e.rs.Split(data) if err != nil { diff --git a/gerasure/gereedsolomon/reconstructor.go b/gerasure/gereedsolomon/reconstructor.go index ef7dc21..e79355b 100644 --- a/gerasure/gereedsolomon/reconstructor.go +++ b/gerasure/gereedsolomon/reconstructor.go @@ -28,7 +28,18 @@ type Reconstructor struct { // NewReconstructor returns a new Reconstructor. // The options within the given reedsolomon.Encoder determine the number of shards. 
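+// (For example, a caller can encode the data once and take len(shards[0]) as the
+// shard size, as the compliance test above does.)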
// The shardSize and totalDataSize must be discovered out of band; -func NewReconstructor(rs reedsolomon.Encoder, shardSize int) *Reconstructor { +func NewReconstructor(dataShards, parityShards, shardSize int) (*Reconstructor, error) { + if dataShards <= 0 { + return nil, fmt.Errorf("data shards must be > 0") + } + if parityShards <= 0 { + return nil, fmt.Errorf("parity shards must be > 0") + } + rs, err := reedsolomon.New(dataShards, parityShards) + if err != nil { + return nil, fmt.Errorf("failed to create reed-solomon reconstructor: %w", err) + } + // All reedsolomon.Encoder instances are guaranteed to satisfy reedsolomon.Extensions. // Calling AllocAligned is supposed to result in better throughput // when actually encoding and decoding. @@ -47,7 +58,7 @@ func NewReconstructor(rs reedsolomon.Encoder, shardSize int) *Reconstructor { allShards: allShards, shardSize: shardSize, - } + }, nil } // ReconstructData satisfies [gerasure.Reconstructor]. diff --git a/gturbine/gtencoding/benchmark_test.go b/gturbine/gtencoding/benchmark_test.go index 0c671ef..6aabb98 100644 --- a/gturbine/gtencoding/benchmark_test.go +++ b/gturbine/gtencoding/benchmark_test.go @@ -31,15 +31,16 @@ func BenchmarkEncode(b *testing.B) { rand.Read(data) shred := >urbine.Shred{ - Type: gturbine.DataShred, - FullDataSize: size.size, - BlockHash: make([]byte, 32), - GroupID: uuid.New().String(), - Height: 1, - Index: 0, - TotalDataShreds: 16, - TotalRecoveryShreds: 4, - Data: data, + Metadata: >urbine.ShredMetadata{ + FullDataSize: size.size, + BlockHash: make([]byte, 32), + GroupID: uuid.New().String(), + Height: 1, + TotalDataShreds: 16, + TotalRecoveryShreds: 4, + }, + Index: 0, + Data: data, } codec := NewBinaryShardCodec() @@ -65,15 +66,16 @@ func BenchmarkDecode(b *testing.B) { rand.Read(data) shred := >urbine.Shred{ - Type: gturbine.DataShred, - FullDataSize: size.size, - BlockHash: make([]byte, 32), - GroupID: uuid.New().String(), - Height: 1, - Index: 0, - TotalDataShreds: 16, - TotalRecoveryShreds: 4, - Data: data, + Metadata: >urbine.ShredMetadata{ + FullDataSize: size.size, + BlockHash: make([]byte, 32), + GroupID: uuid.New().String(), + Height: 1, + TotalDataShreds: 16, + TotalRecoveryShreds: 4, + }, + Index: 0, + Data: data, } codec := NewBinaryShardCodec() @@ -91,233 +93,3 @@ func BenchmarkDecode(b *testing.B) { }) } } - -// BenchmarkErasureEncoding tests Reed-Solomon encoding with different configurations -func BenchmarkErasureEncoding(b *testing.B) { - configs := []struct { - data int - recovery int - }{ - {4, 2}, // 33% overhead - {8, 4}, // 33% overhead - {16, 4}, // 20% overhead - {32, 8}, // 20% overhead - } - - for _, size := range benchSizes { - for _, cfg := range configs { - shardSize := size.size / cfg.data - name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) - - b.Run(name, func(b *testing.B) { - enc, err := NewEncoder(cfg.data, cfg.recovery) - if err != nil { - b.Fatal(err) - } - - shreds := make([][]byte, cfg.data) - for i := range shreds { - shreds[i] = make([]byte, shardSize) - rand.Read(shreds[i]) - } - - b.ResetTimer() - b.SetBytes(int64(size.size)) - - for i := 0; i < b.N; i++ { - _, err := enc.GenerateRecoveryShreds(shreds) - if err != nil { - b.Fatal(err) - } - } - }) - } - } -} - -// BenchmarkErasureReconstruction tests Reed-Solomon reconstruction with different configurations -func BenchmarkErasureReconstruction(b *testing.B) { - configs := []struct { - data int - recovery int - }{ - {4, 2}, // 33% overhead - {8, 4}, // 33% overhead - {16, 4}, // 20% overhead - {32, 
8}, // 20% overhead - } - - for _, size := range benchSizes { - for _, cfg := range configs { - shardSize := size.size / cfg.data - name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) - - b.Run(name, func(b *testing.B) { - enc, err := NewEncoder(cfg.data, cfg.recovery) - if err != nil { - b.Fatal(err) - } - - // Generate test data - shreds := make([][]byte, cfg.data) - for i := range shreds { - shreds[i] = make([]byte, shardSize) - rand.Read(shreds[i]) - } - - // Generate recovery shreds - recoveryShreds, err := enc.GenerateRecoveryShreds(shreds) - if err != nil { - b.Fatal(err) - } - - // Combine all shreds - allShreds := append(shreds, recoveryShreds...) - - // Simulate worst case - lose maximum recoverable shards - for i := 0; i < cfg.recovery; i++ { - allShreds[i] = nil - } - - b.ResetTimer() - b.SetBytes(int64(size.size)) - - for i := 0; i < b.N; i++ { - // Make a copy since Reconstruct modifies the slice - testShreds := make([][]byte, len(allShreds)) - copy(testShreds, allShreds) - - err := enc.Reconstruct(testShreds) - if err != nil { - b.Fatal(err) - } - } - }) - } - } -} - -// BenchmarkErasureVerification tests Reed-Solomon verification with different configurations -func BenchmarkErasureVerification(b *testing.B) { - configs := []struct { - data int - recovery int - }{ - {4, 2}, // 33% overhead - {8, 4}, // 33% overhead - {16, 4}, // 20% overhead - {32, 8}, // 20% overhead - } - - for _, size := range benchSizes { - for _, cfg := range configs { - shardSize := size.size / cfg.data - name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) - - b.Run(name, func(b *testing.B) { - enc, err := NewEncoder(cfg.data, cfg.recovery) - if err != nil { - b.Fatal(err) - } - - // Generate test data - shreds := make([][]byte, cfg.data) - for i := range shreds { - shreds[i] = make([]byte, shardSize) - rand.Read(shreds[i]) - } - - // Generate recovery shreds - recoveryShreds, err := enc.GenerateRecoveryShreds(shreds) - if err != nil { - b.Fatal(err) - } - - // Combine all shreds - allShreds := append(shreds, recoveryShreds...) 
- - b.ResetTimer() - b.SetBytes(int64(size.size)) - - for i := 0; i < b.N; i++ { - ok, err := enc.Verify(allShreds) - if err != nil { - b.Fatal(err) - } - if !ok { - b.Fatal("verification failed") - } - } - }) - } - } -} - -// BenchmarkFullPipeline tests the complete encoding process -func BenchmarkFullPipeline(b *testing.B) { - configs := []struct { - data int - recovery int - }{ - {16, 4}, // Typical configuration - } - - for _, size := range benchSizes { - for _, cfg := range configs { - name := fmt.Sprintf("%s-%dd-%dr", size.name, cfg.data, cfg.recovery) - - b.Run(name, func(b *testing.B) { - binaryCodec := NewBinaryShardCodec() - erasureEnc, err := NewEncoder(cfg.data, cfg.recovery) - if err != nil { - b.Fatal(err) - } - - data := make([]byte, size.size) - rand.Read(data) - - b.ResetTimer() - b.SetBytes(int64(size.size)) - - for i := 0; i < b.N; i++ { - shreds := make([][]byte, cfg.data) - shredSize := size.size / cfg.data - - for j := 0; j < cfg.data; j++ { - shred := >urbine.Shred{ - Type: gturbine.DataShred, - FullDataSize: size.size, - BlockHash: make([]byte, 32), - GroupID: uuid.New().String(), - Height: 1, - Index: j, - TotalDataShreds: cfg.data, - TotalRecoveryShreds: cfg.recovery, - Data: data[j*shredSize : (j+1)*shredSize], - } - - encoded, err := binaryCodec.Encode(shred) - if err != nil { - b.Fatal(err) - } - shreds[j] = encoded - } - - recoveryShreds, err := erasureEnc.GenerateRecoveryShreds(shreds) - if err != nil { - b.Fatal(err) - } - - allShreds := append(shreds, recoveryShreds...) - ok, err := erasureEnc.Verify(allShreds) - if err != nil { - b.Fatal(err) - } - if !ok { - b.Fatal("verification failed") - } - } - }) - } - } -} diff --git a/gturbine/gtencoding/binary_codec_bench_test.go b/gturbine/gtencoding/binary_codec_bench_test.go index e84a0ee..62ecd82 100644 --- a/gturbine/gtencoding/binary_codec_bench_test.go +++ b/gturbine/gtencoding/binary_codec_bench_test.go @@ -12,15 +12,14 @@ import ( // TestShred represents a reusable test shred configuration type TestShred struct { - size int - dataType gturbine.ShredType + size int } -var testConfigs = []TestShred{ - {64, gturbine.DataShred}, // Minimum size - {1024, gturbine.DataShred}, // 1KB - {64 * 1024, gturbine.DataShred}, // 64KB - {1024 * 1024, gturbine.DataShred}, // 1MB +var testConfigs = []int{ + 64, // Minimum size + 1024, // 1KB + 64 * 1024, // 64KB + 1024 * 1024, // 1MB } // BenchmarkBinaryCodec runs comprehensive benchmarks for the binary codec @@ -51,34 +50,36 @@ func BenchmarkBinaryCodecParallel(b *testing.B) { } // Helper to create consistent benchmark names -func benchName(op string, cfg TestShred) string { - return fmt.Sprintf("%s/%dB", op, cfg.size) +func benchName(op string, size int) string { + return fmt.Sprintf("%s/%dB", op, size) } // Helper to create a test shred -func createTestShred(cfg TestShred) *gturbine.Shred { - data := make([]byte, cfg.size) +func createTestShred(size int) *gturbine.Shred { + data := make([]byte, size) rand.Read(data) return >urbine.Shred{ - Type: cfg.dataType, - FullDataSize: cfg.size, - BlockHash: bytes.Repeat([]byte{0xFF}, blockHashSize), // Fixed pattern for consistent benchmarking - GroupID: uuid.New().String(), - Height: 1, - Index: 0, - TotalDataShreds: 16, - TotalRecoveryShreds: 4, - Data: data, + Metadata: >urbine.ShredMetadata{ + FullDataSize: size, + BlockHash: bytes.Repeat([]byte{0xFF}, blockHashSize), // Fixed pattern for consistent benchmarking + GroupID: uuid.New().String(), + Height: 1, + TotalDataShreds: 16, + TotalRecoveryShreds: 4, + }, + Index: 0, + 
Data: data, + Hash: bytes.Repeat([]byte{0xFF}, blockHashSize), // Fixed pattern for consistent benchmarking, } } -func benchmarkEncode(b *testing.B, cfg TestShred) { +func benchmarkEncode(b *testing.B, size int) { codec := NewBinaryShardCodec() - shred := createTestShred(cfg) + shred := createTestShred(size) b.ResetTimer() - b.SetBytes(int64(cfg.size + prefixSize)) + b.SetBytes(int64(size + prefixSize)) for i := 0; i < b.N; i++ { _, err := codec.Encode(shred) @@ -88,16 +89,16 @@ func benchmarkEncode(b *testing.B, cfg TestShred) { } } -func benchmarkDecode(b *testing.B, cfg TestShred) { +func benchmarkDecode(b *testing.B, size int) { codec := NewBinaryShardCodec() - shred := createTestShred(cfg) + shred := createTestShred(size) encoded, err := codec.Encode(shred) if err != nil { b.Fatal(err) } b.ResetTimer() - b.SetBytes(int64(cfg.size + prefixSize)) + b.SetBytes(int64(size + prefixSize)) for i := 0; i < b.N; i++ { _, err := codec.Decode(encoded) @@ -107,12 +108,12 @@ func benchmarkDecode(b *testing.B, cfg TestShred) { } } -func benchmarkRoundTrip(b *testing.B, cfg TestShred) { +func benchmarkRoundTrip(b *testing.B, size int) { codec := NewBinaryShardCodec() - shred := createTestShred(cfg) + shred := createTestShred(size) b.ResetTimer() - b.SetBytes(int64(cfg.size + prefixSize)) + b.SetBytes(int64(size + prefixSize)) for i := 0; i < b.N; i++ { encoded, err := codec.Encode(shred) @@ -126,12 +127,12 @@ func benchmarkRoundTrip(b *testing.B, cfg TestShred) { } } -func benchmarkEncodeParallel(b *testing.B, cfg TestShred) { +func benchmarkEncodeParallel(b *testing.B, size int) { codec := NewBinaryShardCodec() - shred := createTestShred(cfg) + shred := createTestShred(size) b.ResetTimer() - b.SetBytes(int64(cfg.size + prefixSize)) + b.SetBytes(int64(size + prefixSize)) b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -143,16 +144,16 @@ func benchmarkEncodeParallel(b *testing.B, cfg TestShred) { }) } -func benchmarkDecodeParallel(b *testing.B, cfg TestShred) { +func benchmarkDecodeParallel(b *testing.B, size int) { codec := NewBinaryShardCodec() - shred := createTestShred(cfg) + shred := createTestShred(size) encoded, err := codec.Encode(shred) if err != nil { b.Fatal(err) } b.ResetTimer() - b.SetBytes(int64(cfg.size + prefixSize)) + b.SetBytes(int64(size + prefixSize)) b.RunParallel(func(pb *testing.PB) { for pb.Next() { diff --git a/gturbine/gtencoding/binary_encoder.go b/gturbine/gtencoding/binary_encoder.go index 0c4d326..a9b3d0a 100644 --- a/gturbine/gtencoding/binary_encoder.go +++ b/gturbine/gtencoding/binary_encoder.go @@ -13,7 +13,6 @@ const ( int32Size = 4 int64Size = 8 versionSize = int16Size - typeSize = int16Size uuidSize = 16 fullDataSizeSize = int64Size @@ -23,8 +22,9 @@ const ( indexSize = int64Size totalDataShredsSize = int64Size totalRecoveryShredsSize = int64Size + shredHashSize = 32 - prefixSize = versionSize + typeSize + fullDataSizeSize + blockHashSize + groupIDSize + heightSize + indexSize + totalDataShredsSize + totalRecoveryShredsSize + prefixSize = versionSize + fullDataSizeSize + blockHashSize + groupIDSize + heightSize + indexSize + totalDataShredsSize + totalRecoveryShredsSize + shredHashSize binaryVersion = 1 ) @@ -41,33 +41,35 @@ func (bsc *BinaryShardCodec) Encode(shred *gturbine.Shred) ([]byte, error) { // Write version binary.LittleEndian.PutUint16(out[:2], binaryVersion) - // Write type - binary.LittleEndian.PutUint16(out[2:4], uint16(shred.Type)) + m := shred.Metadata // Write full data size - binary.LittleEndian.PutUint64(out[4:12], 
uint64(shred.FullDataSize)) + binary.LittleEndian.PutUint64(out[2:10], uint64(m.FullDataSize)) // Write block hash - copy(out[12:44], shred.BlockHash) + copy(out[10:42], m.BlockHash) - uid, err := uuid.Parse(shred.GroupID) + uid, err := uuid.Parse(m.GroupID) if err != nil { return nil, fmt.Errorf("failed to parse group ID: %w", err) } // Write group ID - copy(out[44:60], uid[:]) + copy(out[42:58], uid[:]) // Write height - binary.LittleEndian.PutUint64(out[60:68], shred.Height) + binary.LittleEndian.PutUint64(out[58:66], m.Height) // Write index - binary.LittleEndian.PutUint64(out[68:76], uint64(shred.Index)) + binary.LittleEndian.PutUint64(out[66:74], uint64(shred.Index)) // Write total data shreds - binary.LittleEndian.PutUint64(out[76:84], uint64(shred.TotalDataShreds)) + binary.LittleEndian.PutUint64(out[74:82], uint64(m.TotalDataShreds)) // Write total recovery shreds - binary.LittleEndian.PutUint64(out[84:92], uint64(shred.TotalRecoveryShreds)) + binary.LittleEndian.PutUint64(out[82:90], uint64(m.TotalRecoveryShreds)) + + // Write hash + copy(out[90:122], shred.Hash) // Write data copy(out[prefixSize:], shred.Data) @@ -85,36 +87,41 @@ func (bsc *BinaryShardCodec) Decode(data []byte) (*gturbine.Shred, error) { return nil, fmt.Errorf("unsupported version: %d", version) } - // Read type - shred.Type = gturbine.ShredType(binary.LittleEndian.Uint16(data[2:4])) + m := new(gturbine.ShredMetadata) // Read full data size - shred.FullDataSize = int(binary.LittleEndian.Uint64(data[4:12])) + m.FullDataSize = int(binary.LittleEndian.Uint64(data[2:10])) // Read block hash - shred.BlockHash = make([]byte, blockHashSize) - copy(shred.BlockHash, data[12:44]) + m.BlockHash = make([]byte, blockHashSize) + copy(m.BlockHash, data[10:42]) // Read group ID uid := uuid.UUID{} - copy(uid[:], data[44:60]) - shred.GroupID = uid.String() + copy(uid[:], data[42:58]) + m.GroupID = uid.String() // Read height - shred.Height = binary.LittleEndian.Uint64(data[60:68]) + m.Height = binary.LittleEndian.Uint64(data[58:66]) // Read index - shred.Index = int(binary.LittleEndian.Uint64(data[68:76])) + shred.Index = int(binary.LittleEndian.Uint64(data[66:74])) // Read total data shreds - shred.TotalDataShreds = int(binary.LittleEndian.Uint64(data[76:84])) + m.TotalDataShreds = int(binary.LittleEndian.Uint64(data[74:82])) // Read total recovery shreds - shred.TotalRecoveryShreds = int(binary.LittleEndian.Uint64(data[84:92])) + m.TotalRecoveryShreds = int(binary.LittleEndian.Uint64(data[82:90])) + + // Read hash + shred.Hash = make([]byte, shredHashSize) + copy(shred.Hash, data[90:122]) // Read data shred.Data = make([]byte, len(data)-prefixSize) copy(shred.Data, data[prefixSize:]) + shred.Metadata = m + return &shred, nil } diff --git a/gturbine/gtencoding/binary_encoder_test.go b/gturbine/gtencoding/binary_encoder_test.go index 2463087..645a9c9 100644 --- a/gturbine/gtencoding/binary_encoder_test.go +++ b/gturbine/gtencoding/binary_encoder_test.go @@ -17,42 +17,51 @@ func TestBinaryShardCodec_EncodeDecode(t *testing.T) { { name: "basic encode/decode", shred: >urbine.Shred{ - FullDataSize: 1000, - BlockHash: bytes.Repeat([]byte{1}, 32), - GroupID: uuid.New().String(), - Height: 12345, - Index: 5, - TotalDataShreds: 10, - TotalRecoveryShreds: 2, - Data: []byte("test data"), + Metadata: >urbine.ShredMetadata{ + FullDataSize: 1000, + BlockHash: bytes.Repeat([]byte{1}, 32), + GroupID: uuid.New().String(), + Height: 12345, + TotalDataShreds: 10, + TotalRecoveryShreds: 2, + }, + Index: 5, + Data: []byte("test data"), + Hash: 
bytes.Repeat([]byte{2}, 32), }, wantErr: false, }, { name: "empty data", shred: >urbine.Shred{ - FullDataSize: 0, - BlockHash: bytes.Repeat([]byte{2}, 32), - GroupID: uuid.New().String(), - Height: 67890, - Index: 0, - TotalDataShreds: 1, - TotalRecoveryShreds: 0, - Data: []byte{}, + Metadata: >urbine.ShredMetadata{ + FullDataSize: 0, + BlockHash: bytes.Repeat([]byte{2}, 32), + GroupID: uuid.New().String(), + Height: 67890, + TotalDataShreds: 1, + TotalRecoveryShreds: 0, + }, + Index: 0, + Data: []byte{}, + Hash: bytes.Repeat([]byte{2}, 32), }, wantErr: false, }, { name: "large data", shred: >urbine.Shred{ - FullDataSize: 1000000, - BlockHash: bytes.Repeat([]byte{3}, 32), - GroupID: uuid.New().String(), - Height: 999999, - Index: 50, - TotalDataShreds: 100, - TotalRecoveryShreds: 20, - Data: bytes.Repeat([]byte("large data"), 1000), + Metadata: >urbine.ShredMetadata{ + FullDataSize: 1000000, + BlockHash: bytes.Repeat([]byte{3}, 32), + GroupID: uuid.New().String(), + Height: 999999, + TotalDataShreds: 100, + TotalRecoveryShreds: 20, + }, + Index: 50, + Data: bytes.Repeat([]byte("large data"), 1000), + Hash: bytes.Repeat([]byte{2}, 32), }, wantErr: false, }, @@ -80,38 +89,46 @@ func TestBinaryShardCodec_EncodeDecode(t *testing.T) { return } + sm := tt.shred.Metadata + + dm := decoded.Metadata + // Verify all fields match - if decoded.FullDataSize != tt.shred.FullDataSize { - t.Errorf("FullDataSize mismatch: got %v, want %v", decoded.FullDataSize, tt.shred.FullDataSize) + if dm.FullDataSize != sm.FullDataSize { + t.Errorf("FullDataSize mismatch: got %v, want %v", dm.FullDataSize, sm.FullDataSize) } - if !bytes.Equal(decoded.BlockHash, tt.shred.BlockHash) { - t.Errorf("BlockHash mismatch: got %v, want %v", decoded.BlockHash, tt.shred.BlockHash) + if !bytes.Equal(dm.BlockHash, sm.BlockHash) { + t.Errorf("BlockHash mismatch: got %v, want %v", dm.BlockHash, sm.BlockHash) } - if decoded.GroupID != tt.shred.GroupID { - t.Errorf("GroupID mismatch: got %v, want %v", decoded.GroupID, tt.shred.GroupID) + if dm.GroupID != sm.GroupID { + t.Errorf("GroupID mismatch: got %v, want %v", dm.GroupID, sm.GroupID) } - if decoded.Height != tt.shred.Height { - t.Errorf("Height mismatch: got %v, want %v", decoded.Height, tt.shred.Height) + if dm.Height != sm.Height { + t.Errorf("Height mismatch: got %v, want %v", dm.Height, sm.Height) } if decoded.Index != tt.shred.Index { t.Errorf("Index mismatch: got %v, want %v", decoded.Index, tt.shred.Index) } - if decoded.TotalDataShreds != tt.shred.TotalDataShreds { - t.Errorf("TotalDataShreds mismatch: got %v, want %v", decoded.TotalDataShreds, tt.shred.TotalDataShreds) + if dm.TotalDataShreds != sm.TotalDataShreds { + t.Errorf("TotalDataShreds mismatch: got %v, want %v", dm.TotalDataShreds, sm.TotalDataShreds) } - if decoded.TotalRecoveryShreds != tt.shred.TotalRecoveryShreds { - t.Errorf("TotalRecoveryShreds mismatch: got %v, want %v", decoded.TotalRecoveryShreds, tt.shred.TotalRecoveryShreds) + if dm.TotalRecoveryShreds != sm.TotalRecoveryShreds { + t.Errorf("TotalRecoveryShreds mismatch: got %v, want %v", dm.TotalRecoveryShreds, sm.TotalRecoveryShreds) } if !bytes.Equal(decoded.Data, tt.shred.Data) { t.Errorf("Data mismatch: got %v, want %v", decoded.Data, tt.shred.Data) } + + if !bytes.Equal(decoded.Hash, tt.shred.Hash) { + t.Errorf("Hash mismatch: got %v, want %v", decoded.Hash, tt.shred.Hash) + } }) } } @@ -119,7 +136,9 @@ func TestBinaryShardCodec_EncodeDecode(t *testing.T) { func TestBinaryShardCodec_InvalidGroupID(t *testing.T) { codec := &BinaryShardCodec{} shred 
:= >urbine.Shred{ - GroupID: "invalid-uuid", + Metadata: >urbine.ShredMetadata{ + GroupID: "invalid-uuid", + }, // Other fields can be empty for this test } @@ -132,14 +151,16 @@ func TestBinaryShardCodec_InvalidGroupID(t *testing.T) { func TestBinaryShardCodec_DataSizes(t *testing.T) { codec := &BinaryShardCodec{} shred := >urbine.Shred{ - FullDataSize: 1000, - BlockHash: bytes.Repeat([]byte{1}, 32), - GroupID: uuid.New().String(), - Height: 12345, - Index: 5, - TotalDataShreds: 10, - TotalRecoveryShreds: 2, - Data: []byte("test data"), + Metadata: >urbine.ShredMetadata{ + FullDataSize: 1000, + BlockHash: bytes.Repeat([]byte{1}, 32), + GroupID: uuid.New().String(), + Height: 12345, + TotalDataShreds: 10, + TotalRecoveryShreds: 2, + }, + Index: 5, + Data: []byte("test data"), } encoded, err := codec.Encode(shred) diff --git a/gturbine/gtencoding/erasure.go b/gturbine/gtencoding/erasure.go deleted file mode 100644 index d5032d6..0000000 --- a/gturbine/gtencoding/erasure.go +++ /dev/null @@ -1,88 +0,0 @@ -package gtencoding - -import ( - "fmt" - - "github.com/klauspost/reedsolomon" -) - -const maxTotalShreds = 128 - -type Encoder struct { - enc reedsolomon.Encoder - dataShreds int - recoveryShreds int -} - -func NewEncoder(dataShreds, recoveryShreds int) (*Encoder, error) { - if dataShreds <= 0 { - return nil, fmt.Errorf("data shreds must be > 0") - } - if recoveryShreds <= 0 { - return nil, fmt.Errorf("recovery shreds must be > 0") - } - if dataShreds+recoveryShreds > maxTotalShreds { - return nil, fmt.Errorf("total shreds must be <= %d", maxTotalShreds) - } - - enc, err := reedsolomon.New(dataShreds, recoveryShreds) - if err != nil { - return nil, fmt.Errorf("failed to create reed-solomon encoder: %w", err) - } - - return &Encoder{ - enc: enc, - dataShreds: dataShreds, - recoveryShreds: recoveryShreds, - }, nil -} - -func (e *Encoder) GenerateRecoveryShreds(shreds [][]byte) ([][]byte, error) { - if len(shreds) != e.dataShreds { - return nil, fmt.Errorf("expected %d data shreds, got %d", e.dataShreds, len(shreds)) - } - - totalShreds := make([][]byte, e.dataShreds+e.recoveryShreds) - copy(totalShreds, shreds) - - for i := e.dataShreds; i < len(totalShreds); i++ { - totalShreds[i] = make([]byte, len(shreds[0])) - } - - if err := e.enc.Encode(totalShreds); err != nil { - return nil, fmt.Errorf("encoding failed: %w", err) - } - - return totalShreds[e.dataShreds:], nil -} - -func (e *Encoder) Reconstruct(allShreds [][]byte) error { - if len(allShreds) != e.dataShreds+e.recoveryShreds { - return fmt.Errorf("expected %d total shreds, got %d", e.dataShreds+e.recoveryShreds, len(allShreds)) - } - - // Count non-nil shreds - validShreds := 0 - for _, shred := range allShreds { - if shred != nil { - validShreds++ - } - } - - // Need at least dataShreds valid pieces for reconstruction - if validShreds < e.dataShreds { - return fmt.Errorf("insufficient shreds for reconstruction: have %d, need %d", validShreds, e.dataShreds) - } - - if err := e.enc.Reconstruct(allShreds); err != nil { - return fmt.Errorf("reconstruction failed: %w", err) - } - return nil -} - -func (e *Encoder) Verify(allShreds [][]byte) (bool, error) { - if len(allShreds) != e.dataShreds+e.recoveryShreds { - return false, fmt.Errorf("expected %d total shreds, got %d", e.dataShreds+e.recoveryShreds, len(allShreds)) - } - return e.enc.Verify(allShreds) -} diff --git a/gturbine/gtencoding/erasure_test.go b/gturbine/gtencoding/erasure_test.go deleted file mode 100644 index 27e2590..0000000 --- a/gturbine/gtencoding/erasure_test.go +++ 
/dev/null @@ -1,189 +0,0 @@ -package gtencoding - -import ( - "bytes" - "crypto/rand" - "fmt" - mrand "math/rand" - "testing" -) - -const ( - // NOTE: blocksize needs to be a multiple of 64 for reed solomon to work - blockSize = 128 * 1024 * 1024 // 128MB blocks same as solana - numTests = 5 // Number of iterations for randomized tests -) - -func TestEncoderRealWorld(t *testing.T) { - t.Run("solana-like configuration", func(t *testing.T) { - enc, err := NewEncoder(32, 32) - if err != nil { - t.Fatal(err) - } - - shredSize := blockSize / 32 - dataShreds := make([][]byte, 32) - for i := range dataShreds { - dataShreds[i] = make([]byte, shredSize) - if _, err := rand.Read(dataShreds[i]); err != nil { - t.Fatal(err) - } - } - - recoveryShreds, err := enc.GenerateRecoveryShreds(dataShreds) - if err != nil { - t.Fatal(err) - } - - if len(recoveryShreds) != 32 { - t.Fatalf("expected 32 recovery shreds, got %d", len(recoveryShreds)) - } - - allShreds := append(dataShreds, recoveryShreds...) - - scenarios := []struct { - name string - numDataLost int - numParityLost int - shouldRecover bool - }{ - {"lose 16 data shreds", 16, 0, true}, - {"lose 16 data and 16 parity shreds", 16, 16, true}, - {"lose 31 data shreds", 31, 0, true}, - {"lose all parity shreds", 0, 32, true}, - {"lose 32 data shreds and 1 parity shred", 32, 1, false}, - {"lose 31 data and 2 parity shreds", 31, 2, false}, - } - - for _, sc := range scenarios { - t.Run(sc.name, func(t *testing.T) { - testShreds := make([][]byte, len(allShreds)) - copy(testShreds, allShreds) - - // Remove data shreds - for i := 0; i < sc.numDataLost; i++ { - testShreds[i] = nil - } - - // Remove parity shreds - for i := 0; i < sc.numParityLost; i++ { - testShreds[32+i] = nil - } - - err := enc.Reconstruct(testShreds) - if sc.shouldRecover { - if err != nil { - t.Errorf("failed to reconstruct when it should: %v", err) - return - } - // Verify reconstruction - for i := range dataShreds { - if !bytes.Equal(testShreds[i], dataShreds[i]) { - t.Errorf("shard %d not properly reconstructed", i) - } - } - } else if err == nil { - t.Error("reconstruction succeeded when it should have failed") - } - }) - } - }) - - t.Run("random failure patterns", func(t *testing.T) { - enc, _ := NewEncoder(32, 32) - shredSize := blockSize / 32 - - for i := 0; i < numTests; i++ { - dataShreds := make([][]byte, 32) - for j := range dataShreds { - dataShreds[j] = make([]byte, shredSize) - if _, err := rand.Read(dataShreds[j]); err != nil { - t.Fatal(err) - } - } - - recoveryShreds, err := enc.GenerateRecoveryShreds(dataShreds) - if err != nil { - t.Fatal(err) - } - - allShreds := append(dataShreds, recoveryShreds...) 
- testShreds := make([][]byte, len(allShreds)) - copy(testShreds, allShreds) - - numToRemove := mrand.Intn(32) - removedIndices := make(map[int]bool) - for j := 0; j < numToRemove; j++ { - for { - idx := mrand.Intn(len(testShreds)) - if !removedIndices[idx] { - testShreds[idx] = nil - removedIndices[idx] = true - break - } - } - } - - err = enc.Reconstruct(testShreds) - if numToRemove >= 32 { - if err == nil { - t.Errorf("test %d: reconstruction succeeded with %d shreds removed", i, numToRemove) - } - } else { - if err != nil { - t.Errorf("test %d: failed to reconstruct with %d shreds removed: %v", i, numToRemove, err) - continue - } - - for j := range dataShreds { - if !bytes.Equal(testShreds[j], dataShreds[j]) { - t.Errorf("test %d: shard %d not properly reconstructed", i, j) - } - } - } - } - }) - - t.Run("performance benchmarks", func(t *testing.T) { - if testing.Short() { - t.Skip("skipping performance test in short mode") - } - - enc, _ := NewEncoder(32, 32) - shredSize := blockSize / 32 - dataShreds := make([][]byte, 32) - for i := range dataShreds { - dataShreds[i] = make([]byte, shredSize) - rand.Read(dataShreds[i]) - } - - recoveryShreds, err := enc.GenerateRecoveryShreds(dataShreds) - if err != nil { - t.Fatal(err) - } - - allShreds := append(dataShreds, recoveryShreds...) - lostCounts := []int{8, 16, 24, 31} - - for _, count := range lostCounts { - t.Run(fmt.Sprintf("reconstruct_%d_lost", count), func(t *testing.T) { - testShreds := make([][]byte, len(allShreds)) - copy(testShreds, allShreds) - - for i := 0; i < count; i++ { - testShreds[i] = nil - } - - if err := enc.Reconstruct(testShreds); err != nil { - t.Fatal(err) - } - - for i := range dataShreds { - if !bytes.Equal(testShreds[i], dataShreds[i]) { - t.Errorf("shard %d not properly reconstructed", i) - } - } - }) - } - }) -} diff --git a/gturbine/gtshred/benchmark_test.go b/gturbine/gtshred/benchmark_test.go index 4c98e6e..f058a0f 100644 --- a/gturbine/gtshred/benchmark_test.go +++ b/gturbine/gtshred/benchmark_test.go @@ -3,6 +3,7 @@ package gtshred import ( "context" "crypto/rand" + "crypto/sha256" "testing" "time" ) @@ -15,15 +16,14 @@ func (n *noopCallback) ProcessBlock(height uint64, blockHash []byte, block []byt func BenchmarkShredProcessing(b *testing.B) { sizes := []struct { - name string - size int - chunkSize uint32 + name string + size int }{ - {"8MB", 8 << 20, 1 << 18}, // 256KB chunks - {"16MB", 16 << 20, 1 << 19}, // 512KB chunks - {"32MB", 32 << 20, 1 << 20}, // 1MB chunks - {"64MB", 64 << 20, 1 << 21}, // 2MB chunks - {"128MB", 128 << 20, 1 << 22}, // 4MB chunks + {"8MB", 8 << 20}, // 256KB chunks + {"16MB", 16 << 20}, // 512KB chunks + {"32MB", 32 << 20}, // 1MB chunks + {"64MB", 64 << 20}, // 2MB chunks + {"128MB", 128 << 20}, // 4MB chunks } for _, size := range sizes { @@ -36,7 +36,8 @@ func BenchmarkShredProcessing(b *testing.B) { } // Create processor with noop callback - p := NewProcessor(&noopCallback{}, time.Minute) + hasher := sha256.New + p := NewProcessor(&noopCallback{}, hasher, hasher, time.Minute) go p.RunBackgroundCleanup(context.Background()) // Reset timer before main benchmark loop @@ -44,13 +45,13 @@ func BenchmarkShredProcessing(b *testing.B) { for i := 0; i < b.N; i++ { // Create shred group with appropriate chunk size - group, err := NewShredGroup(block, uint64(i), 32, 32, size.chunkSize) + group, err := ShredBlock(block, hasher, uint64(i), 32, 32) if err != nil { b.Fatal(err) } // Process all data shreds - for _, shred := range group.DataShreds { + for _, shred := range group.Shreds 
{ if err := p.CollectShred(shred); err != nil { b.Fatal(err) } @@ -58,7 +59,7 @@ func BenchmarkShredProcessing(b *testing.B) { b.StopTimer() // Reset processor state between iterations - p.groups = make(map[string]*ShredGroupWithTimestamp) + p.groups = make(map[string]*ReconstructorWithTimestamp) p.completedBlocks = make(map[string]time.Time) b.StartTimer() } @@ -79,7 +80,6 @@ func BenchmarkShredReconstruction(b *testing.B) { // Use 32MB block with 1MB chunks block := make([]byte, 32<<20) - chunkSize := uint32(1 << 20) // 1MB chunks _, err := rand.Read(block) if err != nil { @@ -89,26 +89,27 @@ func BenchmarkShredReconstruction(b *testing.B) { for _, pattern := range patterns { b.Run(pattern.name, func(b *testing.B) { // Create processor - p := NewProcessor(&noopCallback{}, time.Minute) + hasher := sha256.New + p := NewProcessor(&noopCallback{}, hasher, hasher, time.Minute) go p.RunBackgroundCleanup(context.Background()) b.ResetTimer() for i := 0; i < b.N; i++ { // Create shred group with appropriate chunk size - group, err := NewShredGroup(block, uint64(i), 32, 32, chunkSize) + group, err := ShredBlock(block, hasher, uint64(i), 32, 32) if err != nil { b.Fatal(err) } // Simulate packet loss - lossCount := int(float64(len(group.DataShreds)) * pattern.lossRate) + lossCount := int(float64(len(group.Shreds)) * pattern.lossRate) for j := 0; j < lossCount; j++ { - group.DataShreds[j] = nil + group.Shreds[j] = nil } // Process remaining shreds - for _, shred := range group.DataShreds { + for _, shred := range group.Shreds { if shred != nil { if err := p.CollectShred(shred); err != nil { b.Fatal(err) @@ -116,15 +117,8 @@ func BenchmarkShredReconstruction(b *testing.B) { } } - // Process recovery shreds - for _, shred := range group.RecoveryShreds { - if err := p.CollectShred(shred); err != nil { - b.Fatal(err) - } - } - b.StopTimer() - p.groups = make(map[string]*ShredGroupWithTimestamp) + p.groups = make(map[string]*ReconstructorWithTimestamp) p.completedBlocks = make(map[string]time.Time) b.StartTimer() } diff --git a/gturbine/gtshred/process_shred_test.go b/gturbine/gtshred/process_shred_test.go index 393c6ba..251341c 100644 --- a/gturbine/gtshred/process_shred_test.go +++ b/gturbine/gtshred/process_shred_test.go @@ -64,11 +64,11 @@ func TestProcessorShredding(t *testing.T) { }, { name: "uneven block size", - blockSize: DefaultChunkSize*DefaultDataShreds - 1000, + blockSize: 32, }, { name: "oversized block", - blockSize: DefaultChunkSize*DefaultDataShreds + 1, + blockSize: maxBlockSize + 1, expectErr: true, }, { @@ -86,11 +86,12 @@ func TestProcessorShredding(t *testing.T) { t.Run(tc.name, func(t *testing.T) { var cb = new(testProcessorCallback) - p := NewProcessor(cb, time.Minute) + hasher := sha256.New + p := NewProcessor(cb, hasher, hasher, time.Minute) go p.RunBackgroundCleanup(context.Background()) block := makeRandomBlock(tc.blockSize) - group, err := NewShredGroup(block, TestHeight, DefaultDataShreds, DefaultRecoveryShreds, DefaultChunkSize) + b, err := ShredBlock(block, hasher, TestHeight, DefaultDataShreds, DefaultRecoveryShreds) if tc.expectErr { if err == nil { @@ -103,38 +104,32 @@ func TestProcessorShredding(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - // Verify all shreds are properly sized - for i := range group.DataShreds { - if len(group.DataShreds[i].Data) != int(DefaultChunkSize) { - t.Errorf("data shred %d wrong size: got %d want %d", - i, len(group.DataShreds[i].Data), DefaultChunkSize) - } - } - // Collect threshold shreds into processor // collect all data 
shreds except the last 4, so that recovery shreds are necessary to reassemble for i := 0; i < DefaultDataShreds-4; i++ { - p.CollectShred(group.DataShreds[i]) + p.CollectShred(b.Shreds[i]) } // collect all recovery shreds - for i := 0; i < DefaultRecoveryShreds; i++ { - p.CollectShred(group.RecoveryShreds[i]) + for i := DefaultDataShreds; i < DefaultDataShreds+DefaultRecoveryShreds; i++ { + p.CollectShred(b.Shreds[i]) } if p.cb.(*testProcessorCallback).count != 1 { - t.Error("expected ProcessBlock to be called once") + t.Fatal("expected ProcessBlock to be called once") } - blockHash := sha256.Sum256(block) + h := hasher() + h.Write(block) + blockHash := h.Sum(nil) if !bytes.Equal(blockHash[:], cb.blockHash) { - t.Errorf("block hash mismatch: got %v want %v", cb.blockHash, group.BlockHash) + t.Fatalf("block hash mismatch: got %v want %v", cb.blockHash, b.Metadata.BlockHash) } if !bytes.Equal(block, cb.data) { - t.Errorf("reassembled block doesn't match original: got len %d want len %d", + t.Fatalf("reassembled block doesn't match original: got len %d want len %d", len(cb.data), len(block)) } diff --git a/gturbine/gtshred/processor.go b/gturbine/gtshred/processor.go index 1c62858..400c431 100644 --- a/gturbine/gtshred/processor.go +++ b/gturbine/gtshred/processor.go @@ -1,13 +1,17 @@ package gtshred import ( + "bytes" "context" + "errors" "fmt" + "hash" "sync" "time" + "github.com/gordian-engine/gordian/gerasure" + "github.com/gordian-engine/gordian/gerasure/gereedsolomon" "github.com/gordian-engine/gordian/gturbine" - "github.com/gordian-engine/gordian/gturbine/gtencoding" ) // Constants for error checking @@ -17,10 +21,13 @@ const ( maxBlockSize = 128 * 1024 * 1024 // 128MB maximum block size (matches Solana) ) -// ShredGroupWithTimestamp is a ShredGroup with a timestamp for tracking when the group was created (when the first shred was received). -type ShredGroupWithTimestamp struct { - *ShredGroup +// ReconstructorWithTimestamp is a Reconstructor with a timestamp for tracking when the first shred was received. +type ReconstructorWithTimestamp struct { + *gereedsolomon.Reconstructor + Metadata *gturbine.ShredMetadata Timestamp time.Time + + mu sync.Mutex } type Processor struct { @@ -28,13 +35,16 @@ type Processor struct { cb ProcessorCallback // groups is a cache of shred groups currently being processed. - groups map[string]*ShredGroupWithTimestamp + groups map[string]*ReconstructorWithTimestamp groupsMu sync.RWMutex // completedBlocks is a cache of block hashes that have been fully reassembled and should no longer be processed. completedBlocks map[string]time.Time completedBlocksMu sync.RWMutex + shredHasher func() hash.Hash + blockHasher func() hash.Hash + // cleanupInterval is the interval at which stale groups are cleaned up and completed blocks are removed cleanupInterval time.Duration } @@ -45,10 +55,12 @@ type ProcessorCallback interface { } // NewProcessor creates a new Processor with the given callback and cleanup interval. 
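As a usage sketch (not part of the patch): wiring up the new constructor, which now takes separate hash constructors for per-shred and whole-block verification. The `exampleCallback` type below is hypothetical; it only assumes the `ProcessorCallback` interface declared in this file, and the choice of `sha256.New` for both hashers mirrors the tests in this series.

```go
package main

import (
	"context"
	"crypto/sha256"
	"log"
	"time"

	"github.com/gordian-engine/gordian/gturbine/gtshred"
)

// exampleCallback is a hypothetical ProcessorCallback implementation that just logs.
type exampleCallback struct{}

func (exampleCallback) ProcessBlock(height uint64, blockHash []byte, block []byte) error {
	log.Printf("height %d: reassembled %d bytes (block hash %x)", height, len(block), blockHash)
	return nil
}

func main() {
	// Separate hashers for per-shred and block-level verification; both sha256 here.
	p := gtshred.NewProcessor(exampleCallback{}, sha256.New, sha256.New, time.Minute)

	// Stale groups and completed-block markers are expired in the background.
	go p.RunBackgroundCleanup(context.Background())

	// Shreds received from the network would now be fed in via p.CollectShred(...).
}
```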
-func NewProcessor(cb ProcessorCallback, cleanupInterval time.Duration) *Processor { +func NewProcessor(cb ProcessorCallback, shredHasher func() hash.Hash, blockHasher func() hash.Hash, cleanupInterval time.Duration) *Processor { return &Processor{ cb: cb, - groups: make(map[string]*ShredGroupWithTimestamp), + shredHasher: shredHasher, + blockHasher: blockHasher, + groups: make(map[string]*ReconstructorWithTimestamp), completedBlocks: make(map[string]time.Time), cleanupInterval: cleanupInterval, } @@ -77,11 +89,19 @@ func (p *Processor) CollectShred(shred *gturbine.Shred) error { } // Skip shreds from already processed blocks - if p.isCompleted(shred.BlockHash) { + if p.isCompleted(shred.Metadata.BlockHash) { return nil } - group, ok := p.getGroup(shred.GroupID) + h := p.shredHasher() + h.Write(shred.Data) + hash := h.Sum(nil) + + if !bytes.Equal(hash, shred.Hash) { + return fmt.Errorf("shred hash mismatch: got %x want %x", hash, shred.Hash) + } + + group, ok := p.getGroup(shred.Metadata.GroupID) if !ok { // If the group doesn't exist, create it and add the shred return p.initGroup(shred) @@ -90,34 +110,43 @@ func (p *Processor) CollectShred(shred *gturbine.Shred) error { group.mu.Lock() defer group.mu.Unlock() + m := group.Metadata + // After locking the group, check if the block has already been completed. - if p.isCompleted(group.BlockHash) { + if p.isCompleted(m.BlockHash) { return nil } - full, err := group.collectShred(shred) + if err := group.Reconstructor.ReconstructData(nil, shred.Index, shred.Data); err != nil { + if !errors.Is(err, gerasure.ErrIncompleteSet) { + return err + } + return nil + } + + // The block is now full, reconstruct it and process it. + block, err := group.Reconstructor.Data(make([]byte, 0, m.FullDataSize), m.FullDataSize) if err != nil { - return fmt.Errorf("failed to collect data shred: %w", err) + return fmt.Errorf("failed to reconstruct block: %w", err) } - if full { - encoder, err := gtencoding.NewEncoder(group.TotalDataShreds, group.TotalRecoveryShreds) - if err != nil { - return fmt.Errorf("failed to create encoder: %w", err) - } - block, err := group.reconstructBlock(encoder) - if err != nil { - return fmt.Errorf("failed to reconstruct block: %w", err) - } + // Verify the block hash + h = p.blockHasher() + h.Write(block) + blockHash := h.Sum(nil) - if err := p.cb.ProcessBlock(shred.Height, shred.BlockHash, block); err != nil { - return fmt.Errorf("failed to process block: %w", err) - } + if !bytes.Equal(blockHash, m.BlockHash) { + return fmt.Errorf("block hash mismatch: got %x want %x", blockHash, m.BlockHash) + } - p.deleteGroup(shred.GroupID) - // then mark the block as completed at time.Now() - p.setCompleted(shred.BlockHash) + if err := p.cb.ProcessBlock(m.Height, m.BlockHash, block); err != nil { + return fmt.Errorf("failed to process block: %w", err) } + + p.deleteGroup(m.GroupID) + // then mark the block as completed at time.Now() + p.setCompleted(m.BlockHash) + return nil } @@ -149,7 +178,7 @@ func (p *Processor) cleanupStaleGroups(now time.Time) { for id, group := range p.groups { for _, hash := range deleteHashes { // Check if group is associated with a completed block - if string(group.BlockHash) == hash { + if string(group.Metadata.BlockHash) == hash { deleteGroups = append(deleteGroups, id) } } @@ -174,22 +203,17 @@ func (p *Processor) cleanupStaleGroups(now time.Time) { // initGroup creates a new group and adds the first shred to it. 
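A hypothetical receive-path sketch (not part of the patch) showing how decoded shreds feed into `CollectShred`. It assumes the exported `BinaryShardCodec` from gtencoding and a `Decode([]byte) (*gturbine.Shred, error)` method; that signature is inferred from the codec tests, not confirmed here.

```go
package example

import (
	"fmt"

	"github.com/gordian-engine/gordian/gturbine/gtencoding"
	"github.com/gordian-engine/gordian/gturbine/gtshred"
)

// handleIncoming decodes raw shred bytes from the transport and hands them to the
// processor, which verifies the per-shred hash, tracks the group, and invokes the
// callback once the block can be reconstructed and its hash verified.
func handleIncoming(p *gtshred.Processor, codec *gtencoding.BinaryShardCodec, raw []byte) error {
	shred, err := codec.Decode(raw) // assumed signature
	if err != nil {
		return fmt.Errorf("decode shred: %w", err)
	}
	return p.CollectShred(shred)
}
```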
func (p *Processor) initGroup(shred *gturbine.Shred) error { now := time.Now() - group := &ShredGroup{ - DataShreds: make([]*gturbine.Shred, shred.TotalDataShreds), - RecoveryShreds: make([]*gturbine.Shred, shred.TotalRecoveryShreds), - TotalDataShreds: shred.TotalDataShreds, - TotalRecoveryShreds: shred.TotalRecoveryShreds, - GroupID: shred.GroupID, - BlockHash: shred.BlockHash, - Height: shred.Height, - OriginalSize: shred.FullDataSize, - } - group.DataShreds[shred.Index] = shred + m := shred.Metadata + + rcons, err := gereedsolomon.NewReconstructor(m.TotalDataShreds, m.TotalRecoveryShreds, len(shred.Data)) + if err != nil { + return fmt.Errorf("failed to create reconstructor: %w", err) + } p.groupsMu.Lock() - if _, ok := p.groups[shred.GroupID]; ok { + if _, ok := p.groups[shred.Metadata.GroupID]; ok { // If a group already exists, return early to avoid overwriting p.groupsMu.Unlock() @@ -199,16 +223,21 @@ func (p *Processor) initGroup(shred *gturbine.Shred) error { defer p.groupsMu.Unlock() - p.groups[shred.GroupID] = &ShredGroupWithTimestamp{ - ShredGroup: group, - Timestamp: now, + group := &ReconstructorWithTimestamp{ + Reconstructor: rcons, + Metadata: m, + Timestamp: now, } + group.Reconstructor.ReconstructData(nil, shred.Index, shred.Data) + + p.groups[m.GroupID] = group + return nil } // getGroup returns the group with the given ID, if it exists. -func (p *Processor) getGroup(groupID string) (*ShredGroupWithTimestamp, bool) { +func (p *Processor) getGroup(groupID string) (*ReconstructorWithTimestamp, bool) { p.groupsMu.RLock() defer p.groupsMu.RUnlock() group, ok := p.groups[groupID] diff --git a/gturbine/gtshred/processor_test.go b/gturbine/gtshred/processor_test.go index 15e3293..ff70ced 100644 --- a/gturbine/gtshred/processor_test.go +++ b/gturbine/gtshred/processor_test.go @@ -2,6 +2,7 @@ package gtshred import ( "context" + "crypto/sha256" "testing" "time" ) @@ -10,31 +11,32 @@ func TestProcessorMemoryCleanup(t *testing.T) { // Create processor with short cleanup interval for testing var cb = new(testProcessorCallback) cleanupInterval := 100 * time.Millisecond - p := NewProcessor(cb, cleanupInterval) + hasher := sha256.New + p := NewProcessor(cb, hasher, hasher, cleanupInterval) go p.RunBackgroundCleanup(context.Background()) // Create a test block and shred group block := []byte("test block data") - group, err := NewShredGroup(block, 1, 2, 1, 100) + group, err := ShredBlock(block, hasher, 1, 2, 1) if err != nil { t.Fatal(err) } // Process some shreds from the group to mark it complete - for i := 0; i < len(group.DataShreds); i++ { - err := p.CollectShred(group.DataShreds[i]) + for i := 0; i < len(group.Shreds); i++ { + err := p.CollectShred(group.Shreds[i]) if err != nil { t.Fatal(err) } } // Verify block is marked as completed - if _, exists := p.completedBlocks[string(group.BlockHash)]; !exists { + if _, exists := p.completedBlocks[string(group.Metadata.BlockHash)]; !exists { t.Error("block should be marked as completed") } // Try to process another shred from same block - err = p.CollectShred(group.RecoveryShreds[0]) + err = p.CollectShred(group.Shreds[0]) if err != nil { t.Fatal(err) } @@ -51,7 +53,7 @@ func TestProcessorMemoryCleanup(t *testing.T) { // Verify completed block was cleaned up p.completedBlocksMu.RLock() defer p.completedBlocksMu.RUnlock() - if _, exists := p.completedBlocks[string(group.BlockHash)]; exists { + if _, exists := p.completedBlocks[string(group.Metadata.BlockHash)]; exists { t.Error("completed block should have been cleaned up") } } diff 
--git a/gturbine/gtshred/shred_block.go b/gturbine/gtshred/shred_block.go new file mode 100644 index 0000000..e1c3a69 --- /dev/null +++ b/gturbine/gtshred/shred_block.go @@ -0,0 +1,78 @@ +package gtshred + +import ( + "fmt" + "hash" + "sync" + + "github.com/google/uuid" + "github.com/gordian-engine/gordian/gerasure/gereedsolomon" + "github.com/gordian-engine/gordian/gturbine" +) + +// ShreddedBlock contains a shredded block's data and metadata +type ShreddedBlock struct { + Shreds []*gturbine.Shred + Metadata *gturbine.ShredMetadata +} + +// ShredBlock shreds a block into data and recovery shreds. +func ShredBlock(block []byte, hasher func() hash.Hash, height uint64, dataShreds, recoveryShreds int) (*ShreddedBlock, error) { + if len(block) == 0 { + return nil, fmt.Errorf("empty block") + } + if len(block) > maxBlockSize { + return nil, fmt.Errorf("block too large: %d bytes exceeds max size %d", len(block), maxBlockSize) + } + + // Create encoder for this block + encoder, err := gereedsolomon.NewEncoder(dataShreds, recoveryShreds) + if err != nil { + return nil, fmt.Errorf("failed to create encoder: %w", err) + } + + h := hasher() + h.Write(block) + blockHash := h.Sum(nil) + + m := &gturbine.ShredMetadata{ + GroupID: uuid.New().String(), + FullDataSize: len(block), + BlockHash: blockHash[:], + Height: height, + TotalDataShreds: dataShreds, + TotalRecoveryShreds: recoveryShreds, + } + + // Create new shred group + group := &ShreddedBlock{ + Metadata: m, + Shreds: make([]*gturbine.Shred, dataShreds+recoveryShreds), + } + + allShreds, err := encoder.Encode(nil, block) + if err != nil { + return nil, fmt.Errorf("failed to encode block: %w", err) + } + + var wg sync.WaitGroup + wg.Add(len(allShreds)) + for i, shard := range allShreds { + go func(i int, shard []byte) { + defer wg.Done() + h := hasher() + h.Write(shard) + hash := h.Sum(nil) + + group.Shreds[i] = &gturbine.Shred{ + Metadata: m, + Index: i, + Data: shard, + Hash: hash, + } + }(i, shard) + } + wg.Wait() + + return group, nil +} diff --git a/gturbine/gtshred/shred_group.go b/gturbine/gtshred/shred_group.go deleted file mode 100644 index 9b6bc8d..0000000 --- a/gturbine/gtshred/shred_group.go +++ /dev/null @@ -1,229 +0,0 @@ -package gtshred - -import ( - "crypto/sha256" - "fmt" - "sync" - - "github.com/google/uuid" - "github.com/gordian-engine/gordian/gturbine" - "github.com/gordian-engine/gordian/gturbine/gtencoding" -) - -// ShredGroup represents a group of shreds that can be used to reconstruct a block.
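A producer-side sketch (not part of the patch) of the new `ShredBlock` helper added above; the block contents, height, and shred counts are arbitrary.

```go
package main

import (
	"crypto/rand"
	"crypto/sha256"
	"fmt"
	"log"

	"github.com/gordian-engine/gordian/gturbine/gtshred"
)

func main() {
	// A 1 MiB random payload stands in for a real proposed block.
	block := make([]byte, 1<<20)
	if _, err := rand.Read(block); err != nil {
		log.Fatal(err)
	}

	// 16 data + 4 recovery shreds: any 16 of the 20 are enough to rebuild the block.
	shredded, err := gtshred.ShredBlock(block, sha256.New, 12345, 16, 4)
	if err != nil {
		log.Fatal(err)
	}

	m := shredded.Metadata
	fmt.Printf("group %s: %d data + %d recovery shreds, block hash %x\n",
		m.GroupID, m.TotalDataShreds, m.TotalRecoveryShreds, m.BlockHash)

	// Each shred carries the shared metadata plus its own index, payload and hash,
	// ready to be binary-encoded and sent over the transport.
	fmt.Printf("produced %d shreds\n", len(shredded.Shreds))
}
```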
-type ShredGroup struct { - DataShreds []*gturbine.Shred - RecoveryShreds []*gturbine.Shred - TotalDataShreds int - TotalRecoveryShreds int - GroupID string // Changed to string for UUID - BlockHash []byte - Height uint64 // Added to struct level - OriginalSize int - - mu sync.Mutex -} - -// NewShredGroup creates a new ShredGroup from a block of data -func NewShredGroup(block []byte, height uint64, dataShreds, recoveryShreds int, chunkSize uint32) (*ShredGroup, error) { - if len(block) == 0 { - return nil, fmt.Errorf("empty block") - } - if len(block) > maxBlockSize { - return nil, fmt.Errorf("block too large: %d bytes exceeds max size %d", len(block), maxBlockSize) - } - if len(block) > int(chunkSize)*dataShreds { - return nil, fmt.Errorf("block too large for configured shred size: %d bytes exceeds max size %d", len(block), chunkSize*uint32(dataShreds)) - } - - // Create encoder for this block - encoder, err := gtencoding.NewEncoder(dataShreds, recoveryShreds) - if err != nil { - return nil, fmt.Errorf("failed to create encoder: %w", err) - } - - // Calculate block hash for verification - // TODO hasher should be interface. - blockHash := sha256.Sum256(block) - - // Create new shred group - group := &ShredGroup{ - DataShreds: make([]*gturbine.Shred, dataShreds), - RecoveryShreds: make([]*gturbine.Shred, recoveryShreds), - TotalDataShreds: dataShreds, - TotalRecoveryShreds: recoveryShreds, - GroupID: uuid.New().String(), - BlockHash: blockHash[:], - Height: height, - OriginalSize: len(block), - } - - // Create fixed-size data chunks - dataBytes := make([][]byte, dataShreds) - bytesPerShred := int(chunkSize) - - // Initialize all shreds to full chunk size with zeros - for i := 0; i < dataShreds; i++ { - dataBytes[i] = make([]byte, bytesPerShred) - } - - // Copy data into shreds - remaining := len(block) - offset := 0 - for i := 0; i < dataShreds && remaining > 0; i++ { - toCopy := remaining - if toCopy > bytesPerShred { - toCopy = bytesPerShred - } - copy(dataBytes[i], block[offset:offset+toCopy]) - offset += toCopy - remaining -= toCopy - } - - // Generate recovery data using erasure coding - recoveryBytes, err := encoder.GenerateRecoveryShreds(dataBytes) - if err != nil { - return nil, fmt.Errorf("failed to generate recovery shreds: %w", err) - } - - // Create data shreds - for i := range dataBytes { - group.DataShreds[i] = &gturbine.Shred{ - Type: gturbine.DataShred, - Index: i, - TotalDataShreds: dataShreds, - TotalRecoveryShreds: recoveryShreds, - Data: dataBytes[i], - BlockHash: blockHash[:], - GroupID: group.GroupID, - Height: height, - FullDataSize: group.OriginalSize, - } - } - - // Create recovery shreds - for i := range recoveryBytes { - group.RecoveryShreds[i] = &gturbine.Shred{ - Type: gturbine.RecoveryShred, - Index: i, - TotalDataShreds: dataShreds, - TotalRecoveryShreds: recoveryShreds, - Data: recoveryBytes[i], - BlockHash: blockHash[:], - GroupID: group.GroupID, - Height: height, - FullDataSize: group.OriginalSize, - } - } - - return group, nil -} - -// isFull checks if enough shreds are available to attempt reconstruction.
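For reference, the threshold logic being deleted here relies on the Reed-Solomon property that any `TotalDataShreds` of the combined data and recovery shreds are sufficient to rebuild the block. A standalone sketch of that property using the underlying klauspost/reedsolomon library (shard counts and payload are arbitrary):

```go
package main

import (
	"fmt"
	"log"

	"github.com/klauspost/reedsolomon"
)

func main() {
	const dataShards, parityShards = 4, 2

	enc, err := reedsolomon.New(dataShards, parityShards)
	if err != nil {
		log.Fatal(err)
	}

	// Split allocates data + parity shards; Encode fills in the parity shards.
	shards, err := enc.Split([]byte("some block payload to protect"))
	if err != nil {
		log.Fatal(err)
	}
	if err := enc.Encode(shards); err != nil {
		log.Fatal(err)
	}

	// Lose up to parityShards shards (here one data and one parity shard)
	// and the original data is still recoverable.
	shards[0], shards[5] = nil, nil
	if err := enc.Reconstruct(shards); err != nil {
		log.Fatal(err)
	}

	ok, err := enc.Verify(shards)
	fmt.Println("verified after reconstruction:", ok, err)
}
```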
-func (g *ShredGroup) isFull() bool { - valid := 0 - for _, s := range g.DataShreds { - if s != nil { - valid++ - } - } - - for _, s := range g.RecoveryShreds { - if s != nil { - valid++ - } - } - - return valid >= g.TotalDataShreds -} - -// reconstructBlock attempts to reconstruct the original block from available shreds -func (g *ShredGroup) reconstructBlock(encoder *gtencoding.Encoder) ([]byte, error) { - // Extract data bytes for erasure coding - allBytes := make([][]byte, len(g.DataShreds)+len(g.RecoveryShreds)) - - // Copy available data shreds - for i, shred := range g.DataShreds { - if shred != nil { - allBytes[i] = make([]byte, len(shred.Data)) - copy(allBytes[i], shred.Data) - } - } - - // Copy available recovery shreds - for i, shred := range g.RecoveryShreds { - if shred != nil { - allBytes[i+len(g.DataShreds)] = make([]byte, len(shred.Data)) - copy(allBytes[i+len(g.DataShreds)], shred.Data) - } - } - - // Reconstruct missing data - if err := encoder.Reconstruct(allBytes); err != nil { - return nil, fmt.Errorf("failed to reconstruct data: %w", err) - } - - // Combine data shreds - reconstructed := make([]byte, 0, g.OriginalSize) - remaining := g.OriginalSize - - for i := 0; i < len(g.DataShreds) && remaining > 0; i++ { - if allBytes[i] == nil { - return nil, fmt.Errorf("reconstruction failed: missing data for shard %d", i) - } - toCopy := remaining - if toCopy > len(allBytes[i]) { - toCopy = len(allBytes[i]) - } - reconstructed = append(reconstructed, allBytes[i][:toCopy]...) - remaining -= toCopy - } - - // Verify reconstructed block hash - // TODO hasher should be interface. - computedHash := sha256.Sum256(reconstructed) - if string(computedHash[:]) != string(g.BlockHash) { - return nil, fmt.Errorf("block hash mismatch after reconstruction") - } - - return reconstructed, nil -} - -// collectShred adds a data shred to the group -func (g *ShredGroup) collectShred(shred *gturbine.Shred) (bool, error) { - if shred == nil { - return false, fmt.Errorf("nil shred") - } - - // Validate shred matches group parameters - if shred.GroupID != g.GroupID { - return false, fmt.Errorf("group ID mismatch: got %s, want %s", shred.GroupID, g.GroupID) - } - if shred.Height != g.Height { - return false, fmt.Errorf("height mismatch: got %d, want %d", shred.Height, g.Height) - } - if string(shred.BlockHash) != string(g.BlockHash) { - return false, fmt.Errorf("block hash mismatch") - } - - switch shred.Type { - case gturbine.DataShred: - // Validate shred index - if int(shred.Index) >= len(g.DataShreds) { - return false, fmt.Errorf("invalid data shred index: %d", shred.Index) - } - - g.DataShreds[shred.Index] = shred - case gturbine.RecoveryShred: - // Validate shred index - if int(shred.Index) >= len(g.RecoveryShreds) { - return false, fmt.Errorf("invalid recovery shred index: %d", shred.Index) - } - - g.RecoveryShreds[shred.Index] = shred - default: - return false, fmt.Errorf("invalid shred type: %d", shred.Type) - } - - return g.isFull(), nil -} diff --git a/gturbine/turbine.go b/gturbine/turbine.go index a534c72..eb9c89d 100644 --- a/gturbine/turbine.go +++ b/gturbine/turbine.go @@ -12,27 +12,25 @@ type Layer struct { Children []*Layer } -type ShredType int - -const ( - DataShred ShredType = iota - RecoveryShred -) +// ShredMetadata contains metadata required to reconstruct a block from its shreds +type ShredMetadata struct { + GroupID string + FullDataSize int + BlockHash []byte + Height uint64 + TotalDataShreds int + TotalRecoveryShreds int +} // Shred represents a piece of a block that can be 
sent over the network type Shred struct { // Metadata for block reconstruction - FullDataSize int // Size of the full block - BlockHash []byte // Hash for data verification - GroupID string // UUID for associating shreds from the same block - Height uint64 // Block height for chain reference - + Metadata *ShredMetadata // Shred-specific metadata - Type ShredType - Index int // Index of this shred within the block - TotalDataShreds int // Total number of shreds for this block - TotalRecoveryShreds int // Total number of shreds for this block + Index int // Index of this shred within the block + Hash []byte // Hash of the shred data + // Shred data Data []byte // The actual shred data } diff --git a/gturbine/turbine_test.go b/gturbine/turbine_test.go index 07c0758..8b95393 100644 --- a/gturbine/turbine_test.go +++ b/gturbine/turbine_test.go @@ -3,6 +3,7 @@ package gturbine_test import ( "bytes" "context" + "crypto/sha256" "fmt" "net" "sync" @@ -66,7 +67,8 @@ func newTestNode(t *testing.T, basePort int) *testNode { cb := &testBlockHandler{} - processor := gtshred.NewProcessor(cb, time.Minute) + hasher := sha256.New + processor := gtshred.NewProcessor(cb, hasher, hasher, time.Minute) go processor.RunBackgroundCleanup(context.Background()) shredHandler := &testShredHandler{} @@ -108,13 +110,13 @@ func TestBlockPropagation(t *testing.T) { const testHeight = 12345 // Node 1: Shred the block - shredGroup, err := gtshred.NewShredGroup(originalBlock, testHeight, 16, 4, 1024) + shredGroup, err := gtshred.ShredBlock(originalBlock, sha256.New, testHeight, 16, 4) if err != nil { t.Fatalf("Failed to shred block: %v", err) } // Node 1: Encode and send shreds to Node 2 - for i, shred := range append(shredGroup.DataShreds, shredGroup.RecoveryShreds...) { + for i, shred := range shredGroup.Shreds { encodedShred, err := node1.codec.Encode(shred) if err != nil { t.Fatalf("Failed to encode shred: %v", err) @@ -162,13 +164,13 @@ func TestPartialBlockReconstruction(t *testing.T) { const testHeight = 54321 // Create shreds - shredGroup, err := gtshred.NewShredGroup(originalBlock, testHeight, 16, 4, 1024) + shredGroup, err := gtshred.ShredBlock(originalBlock, sha256.New, testHeight, 16, 4) if err != nil { t.Fatalf("Failed to shred block: %v", err) } // Send only minimum required shreds - minShreds := append(shredGroup.DataShreds[:12], shredGroup.RecoveryShreds...) + minShreds := append(shredGroup.Shreds[:12], shredGroup.Shreds[16:]...) for i, shred := range minShreds { encodedShred, err := node1.codec.Encode(shred) if err != nil { From 46738ece473b63311f1f5999f4901bcf860965f9 Mon Sep 17 00:00:00 2001 From: Andrew Gouin Date: Thu, 12 Dec 2024 08:48:33 -0700 Subject: [PATCH 3/4] Pipe through rs opts --- gerasure/gereedsolomon/encoder.go | 4 ++-- gerasure/gereedsolomon/reconstructor.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gerasure/gereedsolomon/encoder.go b/gerasure/gereedsolomon/encoder.go index f8295f8..cf78cc3 100644 --- a/gerasure/gereedsolomon/encoder.go +++ b/gerasure/gereedsolomon/encoder.go @@ -15,14 +15,14 @@ type Encoder struct { // NewEncoder returns a new Encoder. // The options within the given reedsolomon.Encoder determine the number of shards. 
-func NewEncoder(dataShreds, parityShreds int) (*Encoder, error) { +func NewEncoder(dataShreds, parityShreds int, opts ...reedsolomon.Option) (*Encoder, error) { if dataShreds <= 0 { return nil, fmt.Errorf("data shreds must be > 0") } if parityShreds <= 0 { return nil, fmt.Errorf("parity shreds must be > 0") } - rs, err := reedsolomon.New(dataShreds, parityShreds) + rs, err := reedsolomon.New(dataShreds, parityShreds, opts...) if err != nil { return nil, fmt.Errorf("failed to create reed-solomon encoder: %w", err) } diff --git a/gerasure/gereedsolomon/reconstructor.go b/gerasure/gereedsolomon/reconstructor.go index e79355b..2ea8f82 100644 --- a/gerasure/gereedsolomon/reconstructor.go +++ b/gerasure/gereedsolomon/reconstructor.go @@ -28,14 +28,14 @@ type Reconstructor struct { // NewReconstructor returns a new Reconstructor. // The options within the given reedsolomon.Encoder determine the number of shards. // The shardSize and totalDataSize must be discovered out of band; -func NewReconstructor(dataShards, parityShards, shardSize int) (*Reconstructor, error) { +func NewReconstructor(dataShards, parityShards, shardSize int, opts ...reedsolomon.Option) (*Reconstructor, error) { if dataShards <= 0 { return nil, fmt.Errorf("data shards must be > 0") } if parityShards <= 0 { return nil, fmt.Errorf("parity shards must be > 0") } - rs, err := reedsolomon.New(dataShards, parityShards) + rs, err := reedsolomon.New(dataShards, parityShards, opts...) if err != nil { return nil, fmt.Errorf("failed to create reed-solomon reconstructor: %w", err) } From 79e17b8790bafd63506b26d2ecd3797d9bf8b537 Mon Sep 17 00:00:00 2001 From: Andrew Gouin Date: Thu, 12 Dec 2024 08:50:59 -0700 Subject: [PATCH 4/4] panic for non-positive shred counts --- gerasure/gereedsolomon/encoder.go | 12 +++++++++--- gerasure/gereedsolomon/reconstructor.go | 10 ++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/gerasure/gereedsolomon/encoder.go b/gerasure/gereedsolomon/encoder.go index cf78cc3..fa98bce 100644 --- a/gerasure/gereedsolomon/encoder.go +++ b/gerasure/gereedsolomon/encoder.go @@ -17,10 +17,16 @@ type Encoder struct { // The options within the given reedsolomon.Encoder determine the number of shards. func NewEncoder(dataShreds, parityShreds int, opts ...reedsolomon.Option) (*Encoder, error) { if dataShreds <= 0 { - return nil, fmt.Errorf("data shreds must be > 0") + panic(fmt.Errorf( + "BUG: attempted to create reed solomon encoder with dataShreds <= 0, got %d", + dataShreds, + )) } if parityShreds <= 0 { - return nil, fmt.Errorf("parity shreds must be > 0") + panic(fmt.Errorf( + "BUG: attempted to create reed solomon encoder with parityShreds <= 0, got %d", + parityShreds, + )) } rs, err := reedsolomon.New(dataShreds, parityShreds, opts...) if err != nil { return nil, fmt.Errorf("failed to create reed-solomon encoder: %w", err) } @@ -31,7 +37,7 @@ func NewEncoder(dataShreds, parityShreds int, opts ...reedsolomon.Option) (*Enco // Encode satisfies [gerasure.Encoder]. // Callers should assume that the Encoder takes ownership of the given data slice. -func (e *Encoder) Encode(_ context.Context, data []byte) ([][]byte, error) { +func (e Encoder) Encode(_ context.Context, data []byte) ([][]byte, error) { // From the original data, produce new subslices for the data shards and parity shards.
allShards, err := e.rs.Split(data) if err != nil { diff --git a/gerasure/gereedsolomon/reconstructor.go b/gerasure/gereedsolomon/reconstructor.go index 2ea8f82..6d5b696 100644 --- a/gerasure/gereedsolomon/reconstructor.go +++ b/gerasure/gereedsolomon/reconstructor.go @@ -30,10 +30,16 @@ type Reconstructor struct { // The shardSize and totalDataSize must be discovered out of band; func NewReconstructor(dataShards, parityShards, shardSize int, opts ...reedsolomon.Option) (*Reconstructor, error) { if dataShards <= 0 { - return nil, fmt.Errorf("data shards must be > 0") + panic(fmt.Errorf( + "BUG: attempted to create reed solomon reconstructor with dataShards <= 0, got %d", + dataShards, + )) } if parityShards <= 0 { - return nil, fmt.Errorf("parity shards must be > 0") + panic(fmt.Errorf( + "BUG: attempted to create reed solomon reconstructor with parityShards <= 0, got %d", + parityShards, + )) } rs, err := reedsolomon.New(dataShards, parityShards, opts...) if err != nil {
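To round out these two commits, a small sketch (not part of the patch) of passing reedsolomon options through the new variadic parameters; the shard counts, shard size, and the particular options chosen are illustrative only.

```go
package main

import (
	"log"

	"github.com/gordian-engine/gordian/gerasure/gereedsolomon"
	"github.com/klauspost/reedsolomon"
)

func main() {
	const (
		dataShards   = 32
		parityShards = 32
		shardSize    = 1 << 20 // assumed 1 MiB shards, for illustration
	)

	// Options are forwarded verbatim to reedsolomon.New; goroutine tuning is just
	// one example of what can now be passed through.
	opts := []reedsolomon.Option{
		reedsolomon.WithAutoGoroutines(shardSize),
		reedsolomon.WithMaxGoroutines(8),
	}

	enc, err := gereedsolomon.NewEncoder(dataShards, parityShards, opts...)
	if err != nil {
		log.Fatal(err)
	}

	rec, err := gereedsolomon.NewReconstructor(dataShards, parityShards, shardSize, opts...)
	if err != nil {
		log.Fatal(err)
	}

	_, _ = enc, rec
}
```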