diff --git a/go.mod b/go.mod index d852a48c61..a3b2525e82 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/ipfs/go-ipld-format v0.6.0 github.com/ipfs/go-log/v2 v2.5.1 github.com/ipld/go-car v0.6.2 + github.com/klauspost/reedsolomon v1.12.1 github.com/libp2p/go-libp2p v0.33.2 github.com/libp2p/go-libp2p-kad-dht v0.25.2 github.com/libp2p/go-libp2p-pubsub v0.10.1 @@ -228,7 +229,6 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.6 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/klauspost/reedsolomon v1.12.1 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/lib/pq v1.10.7 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/share/new_eds/axis_half.go b/share/new_eds/axis_half.go index 17b29de591..dede70ebbc 100644 --- a/share/new_eds/axis_half.go +++ b/share/new_eds/axis_half.go @@ -1,6 +1,8 @@ package eds import ( + "fmt" + "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/shwap" ) @@ -20,3 +22,48 @@ func (a AxisHalf) ToRow() shwap.Row { } return shwap.NewRow(a.Shares, side) } + +// Extended returns full axis shares from half axis shares. +func (a AxisHalf) Extended() ([]share.Share, error) { + if a.IsParity { + return reconstructShares(a.Shares) + } + return extendShares(a.Shares) +} + +// extendShares constructs full axis shares from original half axis shares. +func extendShares(original []share.Share) ([]share.Share, error) { + if len(original) == 0 { + return nil, fmt.Errorf("original shares are empty") + } + + codec := share.DefaultRSMT2DCodec() + parity, err := codec.Encode(original) + if err != nil { + return nil, fmt.Errorf("encoding: %w", err) + } + shares := make([]share.Share, len(original)*2) + copy(shares, original) + copy(shares[len(original):], parity) + return shares, nil +} + +// reconstructShares constructs full axis shares from parity half axis shares. 
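+// It places the parity shares into the second half of a full-length slice and
+// lets the Reed-Solomon codec decode the missing original half.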
+func reconstructShares(parity []share.Share) ([]share.Share, error) { + if len(parity) == 0 { + return nil, fmt.Errorf("parity shares are empty") + } + + sqLen := len(parity) * 2 + shares := make([]share.Share, sqLen) + for i := sqLen / 2; i < sqLen; i++ { + shares[i] = parity[i-sqLen/2] + } + + codec := share.DefaultRSMT2DCodec() + shares, err := codec.Decode(shares) + if err != nil { + return nil, fmt.Errorf("reconstructing: %w", err) + } + return shares, nil +} diff --git a/share/new_eds/axis_half_test.go b/share/new_eds/axis_half_test.go new file mode 100644 index 0000000000..752add5acd --- /dev/null +++ b/share/new_eds/axis_half_test.go @@ -0,0 +1,32 @@ +package eds + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestExtendAxisHalf(t *testing.T) { + shares := sharetest.RandShares(t, 16) + + original := AxisHalf{ + Shares: shares, + IsParity: false, + } + + extended, err := original.Extended() + require.NoError(t, err) + require.Len(t, extended, len(shares)*2) + + parity := AxisHalf{ + Shares: extended[len(shares):], + IsParity: true, + } + + parityExtended, err := parity.Extended() + require.NoError(t, err) + + require.Equal(t, extended, parityExtended) +} diff --git a/share/shwap/row_namespace_data.go b/share/shwap/row_namespace_data.go index fe7bd6da36..5d424ee0f3 100644 --- a/share/shwap/row_namespace_data.go +++ b/share/shwap/row_namespace_data.go @@ -164,10 +164,14 @@ func (rnd RowNamespaceData) Validate(dah *share.Root, namespace share.Namespace, // verifyInclusion checks the inclusion of the row's shares in the provided root using NMT. func (rnd RowNamespaceData) verifyInclusion(rowRoot []byte, namespace share.Namespace) bool { leaves := make([][]byte, 0, len(rnd.Shares)) - for _, shr := range rnd.Shares { - namespaceBytes := share.GetNamespace(shr) - leaves = append(leaves, append(namespaceBytes, shr...)) + for _, sh := range rnd.Shares { + namespaceBytes := share.GetNamespace(sh) + leave := make([]byte, len(sh)+len(namespaceBytes)) + copy(leave, namespaceBytes) + copy(leave[len(namespaceBytes):], sh) + leaves = append(leaves, leave) } + return rnd.Proof.VerifyNamespace( share.NewSHA256Hasher(), namespace.ToNMT(), diff --git a/share/shwap/sample.go b/share/shwap/sample.go index c521410e35..a58a41910e 100644 --- a/share/shwap/sample.go +++ b/share/shwap/sample.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" + "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/nmt" nmt_pb "github.com/celestiaorg/nmt/pb" "github.com/celestiaorg/rsmt2d" @@ -24,6 +25,27 @@ type Sample struct { ProofType rsmt2d.Axis // ProofType indicates whether the proof is against a row or a column. } +func SampleFromShares(shares []share.Share, proofType rsmt2d.Axis, axisIdx, shrIdx int) (Sample, error) { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(axisIdx)) + for _, shr := range shares { + err := tree.Push(shr) + if err != nil { + return Sample{}, err + } + } + + proof, err := tree.ProveRange(shrIdx, shrIdx+1) + if err != nil { + return Sample{}, err + } + + return Sample{ + Share: shares[shrIdx], + Proof: &proof, + ProofType: proofType, + }, nil +} + // SampleFromProto converts a protobuf Sample back into its domain model equivalent. 
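+// The NMT inclusion proof is rebuilt from its protobuf fields before the share
+// and its proof axis are wrapped back into a Sample.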
func SampleFromProto(s *pb.Sample) Sample { proof := nmt.NewInclusionProof( diff --git a/store/file/codec.go b/store/file/codec.go new file mode 100644 index 0000000000..a27280be11 --- /dev/null +++ b/store/file/codec.go @@ -0,0 +1,38 @@ +package file + +import ( + "sync" + + "github.com/klauspost/reedsolomon" +) + +var codec Codec + +func init() { + codec = NewCodec() +} + +type Codec interface { + Encoder(len int) (reedsolomon.Encoder, error) +} + +type codecCache struct { + cache sync.Map +} + +func NewCodec() Codec { + return &codecCache{} +} + +func (l *codecCache) Encoder(len int) (reedsolomon.Encoder, error) { + enc, ok := l.cache.Load(len) + if !ok { + var err error + enc, err = reedsolomon.New(len/2, len/2, reedsolomon.WithLeopardGF(true)) + if err != nil { + return nil, err + } + l.cache.Store(len, enc) + } + return enc.(reedsolomon.Encoder), nil +} diff --git a/store/file/codec_test.go b/store/file/codec_test.go new file mode 100644 index 0000000000..d6fdbb3045 --- /dev/null +++ b/store/file/codec_test.go @@ -0,0 +1,83 @@ +package file + +import ( + "fmt" + "testing" + + "github.com/klauspost/reedsolomon" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func BenchmarkCodec(b *testing.B) { + minSize, maxSize := 32, 128 + + for size := minSize; size <= maxSize; size *= 2 { + // BenchmarkCodec/Leopard/size:32-10 409194 2793 ns/op + // BenchmarkCodec/Leopard/size:64-10 190969 6170 ns/op + // BenchmarkCodec/Leopard/size:128-10 82821 14287 ns/op + b.Run(fmt.Sprintf("Leopard/size:%v", size), func(b *testing.B) { + enc, err := reedsolomon.New(size/2, size/2, reedsolomon.WithLeopardGF(true)) + require.NoError(b, err) + + shards := newShards(b, size, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = enc.Encode(shards) + require.NoError(b, err) + } + }) + + // BenchmarkCodec/default/size:32-10 222153 5364 ns/op + // BenchmarkCodec/default/size:64-10 58831 20349 ns/op + // BenchmarkCodec/default/size:128-10 14940 80471 ns/op + b.Run(fmt.Sprintf("default/size:%v", size), func(b *testing.B) { + enc, err := reedsolomon.New(size/2, size/2, reedsolomon.WithLeopardGF(false)) + require.NoError(b, err) + + shards := newShards(b, size, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = enc.Encode(shards) + require.NoError(b, err) + } + }) + + // BenchmarkCodec/default-reconstructSome/size:32-10 1263585 954.4 ns/op + // BenchmarkCodec/default-reconstructSome/size:64-10 762273 1554 ns/op + // BenchmarkCodec/default-reconstructSome/size:128-10 429268 2974 ns/op + b.Run(fmt.Sprintf("default-reconstructSome/size:%v", size), func(b *testing.B) { + enc, err := reedsolomon.New(size/2, size/2, reedsolomon.WithLeopardGF(false)) + require.NoError(b, err) + + shards := newShards(b, size, false) + targets := make([]bool, size) + target := size - 2 + targets[target] = true + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = enc.ReconstructSome(shards, targets) + require.NoError(b, err) + shards[target] = nil + } + }) + } +} + +func newShards(b require.TestingT, size int, fillParity bool) [][]byte { + shards := make([][]byte, size) + original := sharetest.RandShares(b, size/2) + copy(shards, original) + + if fillParity { + // fill with parity empty Shares + for j := len(original); j < len(shards); j++ { + shards[j] = make([]byte, len(original[0])) + } + } + return shards +} diff --git a/store/file/header.go b/store/file/header.go new file mode 100644 index 0000000000..d7d0a7760a --- /dev/null +++ b/store/file/header.go @@ -0,0 +1,94 
@@ +package file + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/celestiaorg/celestia-node/share" +) + +const headerSize = 64 + +type header struct { + version fileVersion + fileType fileType + + // Taken directly from EDS + shareSize uint16 + squareSize uint16 + + datahash share.DataHash +} + +type fileVersion uint8 + +const ( + fileV0 fileVersion = iota +) + +type fileType uint8 + +const ( + ods fileType = iota + q1q4 +) + +func (h *header) WriteTo(w io.Writer) (int64, error) { + b := bytes.NewBuffer(make([]byte, 0, headerSize)) + _ = b.WriteByte(byte(h.version)) + _ = b.WriteByte(byte(h.fileType)) + _ = binary.Write(b, binary.LittleEndian, h.shareSize) + _ = binary.Write(b, binary.LittleEndian, h.squareSize) + _, _ = b.Write(h.datahash) + // write padding + _, _ = b.Write(make([]byte, headerSize-b.Len()-1)) + return writeLenEncoded(w, b.Bytes()) +} + +func readHeader(r io.Reader) (*header, error) { + bytesHeader, err := readLenEncoded(r) + if err != nil { + return nil, err + } + if len(bytesHeader) != headerSize-1 { + return nil, fmt.Errorf("readHeader: read %d bytes, expected %d", len(bytesHeader), headerSize) + } + h := &header{ + version: fileVersion(bytesHeader[0]), + fileType: fileType(bytesHeader[1]), + shareSize: binary.LittleEndian.Uint16(bytesHeader[2:4]), + squareSize: binary.LittleEndian.Uint16(bytesHeader[4:6]), + datahash: make([]byte, 32), + } + + copy(h.datahash, bytesHeader[6:6+32]) + return h, err +} + +func writeLenEncoded(w io.Writer, data []byte) (int64, error) { + _, err := w.Write([]byte{byte(len(data))}) + if err != nil { + return 0, err + } + return io.Copy(w, bytes.NewBuffer(data)) +} + +func readLenEncoded(r io.Reader) ([]byte, error) { + lenBuf := make([]byte, 1) + _, err := io.ReadFull(r, lenBuf) + if err != nil { + return nil, err + } + + data := make([]byte, lenBuf[0]) + n, err := io.ReadFull(r, data) + if err != nil { + return nil, err + } + if n != len(data) { + return nil, fmt.Errorf("readLenEncoded: read %d bytes, expected %d", n, len(data)) + } + return data, nil +} diff --git a/store/file/ods.go b/store/file/ods.go new file mode 100644 index 0000000000..51459bb95e --- /dev/null +++ b/store/file/ods.go @@ -0,0 +1,269 @@ +package file + +import ( + "context" + "fmt" + "io" + "os" + "sync" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ eds.Accessor = (*OdsFile)(nil) + +type OdsFile struct { + path string + hdr *header + fl *os.File + + lock sync.RWMutex + ods square +} + +// OpenOdsFile opens an existing file. File has to be closed after usage. +func OpenOdsFile(path string) (*OdsFile, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + h, err := readHeader(f) + if err != nil { + return nil, err + } + + return &OdsFile{ + path: path, + hdr: h, + fl: f, + }, nil +} + +// CreateOdsFile creates a new file. File has to be closed after usage. 
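+// The on-disk layout written here is the fixed-size header (version, file type,
+// share size, square size and data hash) followed by the flattened ODS shares in
+// row-major order; parity shares are never stored and are recomputed on read.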
+func CreateOdsFile( + path string, + datahash share.DataHash, + eds *rsmt2d.ExtendedDataSquare, +) (*OdsFile, error) { + f, err := os.Create(path) + if err != nil { + return nil, fmt.Errorf("file create: %w", err) + } + + h := &header{ + version: fileV0, + fileType: ods, + shareSize: share.Size, // TODO: rsmt2d should expose this field + squareSize: uint16(eds.Width()), + datahash: datahash, + } + + err = writeOdsFile(f, h, eds) + if err != nil { + return nil, fmt.Errorf("writing ODS file: %w", err) + } + + // TODO: fill ods field with data from eds + return &OdsFile{ + path: path, + fl: f, + hdr: h, + }, f.Sync() +} + +func writeOdsFile(w io.Writer, h *header, eds *rsmt2d.ExtendedDataSquare) error { + _, err := h.WriteTo(w) + if err != nil { + return err + } + + for _, shr := range eds.FlattenedODS() { + if _, err := w.Write(shr); err != nil { + return err + } + } + return nil +} + +// Size returns square size of the Accessor. +func (f *OdsFile) Size(context.Context) int { + return f.size() +} + +func (f *OdsFile) size() int { + return int(f.hdr.squareSize) +} + +// Close closes the file. +func (f *OdsFile) Close() error { + return f.fl.Close() +} + +// Sample returns share and corresponding proof for row and column indices. Implementation can +// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned +// Sample. +func (f *OdsFile) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { + // Sample proof axis is selected to optimize read performance. + // - For the first and second quadrants, we read the row axis because it is more efficient to read + // single row than reading full ods to calculate single column + // - For the third quadrants, we read the column axis because it is more efficient to read single + // column than reading full ods to calculate single row + // - For the fourth quadrant, it does not matter which axis we read because we need to read full ods + // to calculate the sample + axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx + if colIdx < f.size()/2 && rowIdx >= f.size()/2 { + axisType, axisIdx, shrIdx = rsmt2d.Col, colIdx, rowIdx + } + + axis, err := f.axis(ctx, axisType, axisIdx) + if err != nil { + return shwap.Sample{}, fmt.Errorf("reading axis: %w", err) + } + + return shwap.SampleFromShares(axis, axisType, axisIdx, shrIdx) +} + +// AxisHalf returns half of shares axis of the given type and index. Side is determined by +// implementation. Implementations should indicate the side in the returned AxisHalf. +func (f *OdsFile) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) { + // read axis from file if axisis row and from top half of the square or if axis is column and from + // left half of the square + if axisIdx < f.size()/2 { + shares, err := f.readAxisHalf(axisType, axisIdx) + if err != nil { + return eds.AxisHalf{}, fmt.Errorf("reading axis half: %w", err) + } + return eds.AxisHalf{ + Shares: shares, + IsParity: false, + }, nil + } + + // if axis is from the second half of the square, read full ods and compute the axis half + err := f.readOds() + if err != nil { + return eds.AxisHalf{}, err + } + + shares, err := f.ods.computeAxisHalf(ctx, axisType, axisIdx) + if err != nil { + return eds.AxisHalf{}, fmt.Errorf("computing axis half: %w", err) + } + return eds.AxisHalf{ + Shares: shares, + IsParity: false, + }, nil +} + +// RowNamespaceData returns data for the given namespace and row index. 
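+// The row is read (or recomputed) as a full extended axis, and namespace
+// filtering together with the NMT proof is delegated to
+// shwap.RowNamespaceDataFromShares.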
+func (f *OdsFile) RowNamespaceData( + ctx context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + shares, err := f.axis(ctx, rsmt2d.Row, rowIdx) + if err != nil { + return shwap.RowNamespaceData{}, err + } + return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx) +} + +// Shares returns data shares extracted from the Accessor. +func (f *OdsFile) Shares(context.Context) ([]share.Share, error) { + err := f.readOds() + if err != nil { + return nil, err + } + return f.ods.shares() +} + +func (f *OdsFile) readAxisHalf(axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) { + f.lock.RLock() + ods := f.ods + f.lock.RUnlock() + if ods != nil { + return f.ods.axisHalf(context.Background(), axisType, axisIdx) + } + + switch axisType { + case rsmt2d.Col: + return f.readCol(axisIdx, 0) + case rsmt2d.Row: + return f.readRow(axisIdx) + } + return nil, fmt.Errorf("unknown axis") +} + +func (f *OdsFile) readOds() error { + f.lock.Lock() + defer f.lock.Unlock() + if f.ods != nil { + return nil + } + + // reset file pointer to the beginning of the file shares data + _, err := f.fl.Seek(headerSize, io.SeekStart) + if err != nil { + return fmt.Errorf("discarding header: %w", err) + } + + square, err := readSquare(f.fl, share.Size, f.size()) + if err != nil { + return fmt.Errorf("reading ods: %w", err) + } + f.ods = square + return nil +} + +func (f *OdsFile) readRow(idx int) ([]share.Share, error) { + shrLn := int(f.hdr.shareSize) + odsLn := int(f.hdr.squareSize) / 2 + + shares := make([]share.Share, odsLn) + + pos := idx * odsLn + offset := pos*shrLn + headerSize + + axsData := make([]byte, odsLn*shrLn) + if _, err := f.fl.ReadAt(axsData, int64(offset)); err != nil { + return nil, err + } + + for i := range shares { + shares[i] = axsData[i*shrLn : (i+1)*shrLn] + } + return shares, nil +} + +func (f *OdsFile) readCol(axisIdx, quadrantIdx int) ([]share.Share, error) { + shrLn := int(f.hdr.shareSize) + odsLn := int(f.hdr.squareSize) / 2 + quadrantOffset := quadrantIdx * odsLn * odsLn * shrLn + + shares := make([]share.Share, odsLn) + for i := range shares { + pos := axisIdx + i*odsLn + offset := pos*shrLn + headerSize + quadrantOffset + + shr := make(share.Share, shrLn) + if _, err := f.fl.ReadAt(shr, int64(offset)); err != nil { + return nil, err + } + shares[i] = shr + } + return shares, nil +} + +func (f *OdsFile) axis(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) { + half, err := f.AxisHalf(ctx, axisType, axisIdx) + if err != nil { + return nil, err + } + + return half.Extended() +} diff --git a/store/file/ods_test.go b/store/file/ods_test.go new file mode 100644 index 0000000000..eb8731e13a --- /dev/null +++ b/store/file/ods_test.go @@ -0,0 +1,146 @@ +package file + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/rand" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + eds "github.com/celestiaorg/celestia-node/share/new_eds" +) + +func TestCreateOdsFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + path := t.TempDir() + "/testfile" + edsIn := edstest.RandEDS(t, 8) + datahash := share.DataHash(rand.Bytes(32)) + f, err := CreateOdsFile(path, datahash, edsIn) + require.NoError(t, err) + + shares, err := f.Shares(ctx) + require.NoError(t, err) + expected := edsIn.FlattenedODS() + 
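+	// Shares is expected to return the ODS in the same row-major order
+	// that rsmt2d's FlattenedODS produces.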
require.Equal(t, expected, shares) + require.Equal(t, datahash, f.hdr.datahash) + require.NoError(t, f.Close()) + + f, err = OpenOdsFile(path) + require.NoError(t, err) + shares, err = f.Shares(ctx) + require.NoError(t, err) + require.Equal(t, expected, shares) + require.Equal(t, datahash, f.hdr.datahash) + require.NoError(t, f.Close()) +} + +func TestReadOdsFromFile(t *testing.T) { + eds := edstest.RandEDS(t, 8) + path := t.TempDir() + "/testfile" + f, err := CreateOdsFile(path, []byte{}, eds) + require.NoError(t, err) + + err = f.readOds() + require.NoError(t, err) + for i, row := range f.ods { + original := eds.Row(uint(i))[:eds.Width()/2] + require.True(t, len(original) == len(row)) + require.Equal(t, original, row) + } +} + +func TestOdsFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + odsSize := 8 + createOdsFile := func(eds *rsmt2d.ExtendedDataSquare) eds.Accessor { + path := t.TempDir() + "/testfile" + fl, err := CreateOdsFile(path, []byte{}, eds) + require.NoError(t, err) + return fl + } + + t.Run("Sample", func(t *testing.T) { + eds.TestAccessorSample(ctx, t, createOdsFile, odsSize) + }) + + t.Run("AxisHalf", func(t *testing.T) { + eds.TestAccessorAxisHalf(ctx, t, createOdsFile, odsSize) + }) + + t.Run("RowNamespaceData", func(t *testing.T) { + eds.TestAccessorRowNamespaceData(ctx, t, createOdsFile, odsSize) + }) + + t.Run("Shares", func(t *testing.T) { + eds.TestAccessorShares(ctx, t, createOdsFile, odsSize) + }) +} + +// ReconstructSome, default codec +// BenchmarkAxisFromOdsFile/Size:32/Axis:row/squareHalf:first(original)-10 455848 2588 ns/op +// BenchmarkAxisFromOdsFile/Size:32/Axis:row/squareHalf:second(extended)-10 9015 203950 ns/op +// BenchmarkAxisFromOdsFile/Size:32/Axis:col/squareHalf:first(original)-10 52734 21178 ns/op +// BenchmarkAxisFromOdsFile/Size:32/Axis:col/squareHalf:second(extended)-10 8830 127452 ns/op +// BenchmarkAxisFromOdsFile/Size:64/Axis:row/squareHalf:first(original)-10 303834 4763 ns/op +// BenchmarkAxisFromOdsFile/Size:64/Axis:row/squareHalf:second(extended)-10 2940 426246 ns/op +// BenchmarkAxisFromOdsFile/Size:64/Axis:col/squareHalf:first(original)-10 27758 42842 ns/op +// BenchmarkAxisFromOdsFile/Size:64/Axis:col/squareHalf:second(extended)-10 3385 353868 ns/op +// BenchmarkAxisFromOdsFile/Size:128/Axis:row/squareHalf:first(original)-10 172086 6455 ns/op +// BenchmarkAxisFromOdsFile/Size:128/Axis:row/squareHalf:second(extended)-10 672 1550386 ns/op +// BenchmarkAxisFromOdsFile/Size:128/Axis:col/squareHalf:first(original)-10 14202 84316 ns/op +// BenchmarkAxisFromOdsFile/Size:128/Axis:col/squareHalf:second(extended)-10 978 1230980 ns/op +func BenchmarkAxisFromOdsFile(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + b.Cleanup(cancel) + + minSize, maxSize := 32, 128 + dir := b.TempDir() + + newFile := func(size int) eds.Accessor { + eds := edstest.RandEDS(b, size) + path := dir + "/testfile" + f, err := CreateOdsFile(path, []byte{}, eds) + require.NoError(b, err) + return f + } + eds.BenchGetHalfAxisFromAccessor(ctx, b, newFile, minSize, maxSize) +} + +// BenchmarkShareFromOdsFile/Size:32/Axis:row/squareHalf:first(original)-10 10339 111328 ns/op +// BenchmarkShareFromOdsFile/Size:32/Axis:row/squareHalf:second(extended)-10 3392 359180 ns/op +// BenchmarkShareFromOdsFile/Size:32/Axis:col/squareHalf:first(original)-10 8925 131352 ns/op +// BenchmarkShareFromOdsFile/Size:32/Axis:col/squareHalf:second(extended)-10 3447 346218 ns/op +// 
BenchmarkShareFromOdsFile/Size:64/Axis:row/squareHalf:first(original)-10 5503 215833 ns/op +// BenchmarkShareFromOdsFile/Size:64/Axis:row/squareHalf:second(extended)-10 1231 1001053 ns/op +// BenchmarkShareFromOdsFile/Size:64/Axis:col/squareHalf:first(original)-10 4711 250001 ns/op +// BenchmarkShareFromOdsFile/Size:64/Axis:col/squareHalf:second(extended)-10 1315 910079 ns/op +// BenchmarkShareFromOdsFile/Size:128/Axis:row/squareHalf:first(original)-10 2364 435748 ns/op +// BenchmarkShareFromOdsFile/Size:128/Axis:row/squareHalf:second(extended)-10 358 3330620 ns/op +// BenchmarkShareFromOdsFile/Size:128/Axis:col/squareHalf:first(original)-10 2114 514642 ns/op +// BenchmarkShareFromOdsFile/Size:128/Axis:col/squareHalf:second(extended)-10 373 3068104 ns/op +func BenchmarkShareFromOdsFile(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + b.Cleanup(cancel) + + minSize, maxSize := 32, 128 + dir := b.TempDir() + + newFile := func(size int) eds.Accessor { + eds := edstest.RandEDS(b, size) + path := dir + "/testfile" + f, err := CreateOdsFile(path, []byte{}, eds) + require.NoError(b, err) + return f + } + + eds.BenchGetSampleFromAccessor(ctx, b, newFile, minSize, maxSize) +} diff --git a/store/file/square.go b/store/file/square.go new file mode 100644 index 0000000000..f4bc7e8416 --- /dev/null +++ b/store/file/square.go @@ -0,0 +1,132 @@ +package file + +import ( + "bufio" + "context" + "fmt" + "io" + + "golang.org/x/sync/errgroup" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" +) + +type square [][]share.Share + +// readSquare reads Shares from the reader and returns a square. It assumes that the reader is +// positioned at the beginning of the Shares. It knows the size of the Shares and the size of the +// square, so reads from reader are limited to exactly the amount of data required. +func readSquare(r io.Reader, shareSize, edsSize int) (square, error) { + odsLn := edsSize / 2 + + square := make(square, odsLn) + for i := range square { + square[i] = make([]share.Share, odsLn) + for j := range square[i] { + square[i][j] = make(share.Share, shareSize) + } + } + + // TODO(@walldiss): run benchmark to find optimal size for this buffer + br := bufio.NewReaderSize(r, 4096) + var total int + for i := 0; i < odsLn; i++ { + for j := 0; j < odsLn; j++ { + n, err := io.ReadFull(br, square[i][j]) + if err != nil { + return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n) + } + if n != shareSize { + return nil, fmt.Errorf("share size mismatch: expected %v, got %v", shareSize, n) + } + total += n + } + } + return square, nil +} + +func (s square) size() int { + return len(s) +} + +func (s square) axisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) { + if s == nil { + return nil, fmt.Errorf("square is nil") + } + + if axisIdx >= s.size() { + return nil, fmt.Errorf("index is out of square bounds") + } + + // square stores rows directly in high level slice, so we can return by accessing row by index + if axisType == rsmt2d.Row { + return s[axisIdx], nil + } + + // construct half column from row ordered square + col := make([]share.Share, s.size()) + for i := 0; i < s.size(); i++ { + col[i] = s[i][axisIdx] + } + return col, nil +} + +func (s square) shares() ([]share.Share, error) { + shares := make([]share.Share, 0, s.size()*s.size()) + for _, row := range s { + shares = append(shares, row...) 
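+		// rows are stored row-major, so appending them one after another
+		// yields the same ordering as the flattened ODS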
+ } + return shares, nil +} + +func (s square) computeAxisHalf( + ctx context.Context, + axisType rsmt2d.Axis, + axisIdx int, +) ([]share.Share, error) { + shares := make([]share.Share, s.size()) + + // extend opposite half of the square while collecting Shares for the first half of required axis + g, ctx := errgroup.WithContext(ctx) + opposite := oppositeAxis(axisType) + for i := 0; i < s.size(); i++ { + i := i + g.Go(func() error { + original, err := s.axisHalf(ctx, opposite, i) + if err != nil { + return err + } + + enc, err := codec.Encoder(s.size() * 2) + if err != nil { + return fmt.Errorf("encoder: %w", err) + } + + shards := make([][]byte, s.size()*2) + copy(shards, original) + + target := make([]bool, s.size()*2) + target[axisIdx] = true + + err = enc.ReconstructSome(shards, target) + if err != nil { + return fmt.Errorf("reconstruct some: %w", err) + } + + shares[i] = shards[axisIdx] + return nil + }) + } + + err := g.Wait() + return shares, err +} + +func oppositeAxis(axis rsmt2d.Axis) rsmt2d.Axis { + if axis == rsmt2d.Col { + return rsmt2d.Row + } + return rsmt2d.Col +}
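
As a reading aid, here is a minimal usage sketch tying the pieces above together. It only exercises APIs introduced in this diff (CreateOdsFile, OdsFile.AxisHalf, AxisHalf.Extended, OdsFile.Sample) and reuses the same test helpers as ods_test.go; the test name and the chosen square size are illustrative assumptions, not part of the change.

package file

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"github.com/tendermint/tendermint/libs/rand"

	"github.com/celestiaorg/rsmt2d"

	"github.com/celestiaorg/celestia-node/share"
	"github.com/celestiaorg/celestia-node/share/eds/edstest"
)

// TestOdsFileUsageSketch is an illustrative walk-through of the new accessor,
// not part of the change set.
func TestOdsFileUsageSketch(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	t.Cleanup(cancel)

	// Persist a random extended square; only its ODS quadrant is written to disk.
	square := edstest.RandEDS(t, 8)
	f, err := CreateOdsFile(t.TempDir()+"/ods", share.DataHash(rand.Bytes(32)), square)
	require.NoError(t, err)

	// Rows in the top half of the square are served directly from the file.
	half, err := f.AxisHalf(ctx, rsmt2d.Row, 0)
	require.NoError(t, err)
	require.False(t, half.IsParity)

	// The half can be re-extended in memory into the full row.
	row, err := half.Extended()
	require.NoError(t, err)
	require.Len(t, row, f.Size(ctx))

	// Sampling picks the cheaper axis internally and returns the share
	// together with its NMT inclusion proof.
	sample, err := f.Sample(ctx, 0, 0)
	require.NoError(t, err)
	require.Equal(t, rsmt2d.Row, sample.ProofType)
	require.NotNil(t, sample.Proof)

	require.NoError(t, f.Close())
}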