diff --git a/share/new_eds/accessor.go b/share/new_eds/accessor.go index 8fe740b29e..8ced738fa4 100644 --- a/share/new_eds/accessor.go +++ b/share/new_eds/accessor.go @@ -27,17 +27,24 @@ type Accessor interface { Shares(ctx context.Context) ([]share.Share, error) } -// AccessorCloser is an interface that groups Accessor and io.Closer interfaces. -type AccessorCloser interface { +// AccessorStreamer is an interface that groups Accessor and Streamer interfaces. +type AccessorStreamer interface { Accessor + Streamer +} + +type Streamer interface { + // Reader returns binary reader for the shares. It should read the shares from the + // ODS part of the square row by row. + Reader() (io.Reader, error) io.Closer } -type accessorCloser struct { +type accessorStreamer struct { Accessor - io.Closer + Streamer } -func WithCloser(a Accessor, c io.Closer) AccessorCloser { - return &accessorCloser{a, c} +func AccessorAndStreamer(a Accessor, s Streamer) AccessorStreamer { + return &accessorStreamer{a, s} } diff --git a/share/new_eds/close_once.go b/share/new_eds/close_once.go index f6d90a6332..c05851db7c 100644 --- a/share/new_eds/close_once.go +++ b/share/new_eds/close_once.go @@ -3,6 +3,7 @@ package eds import ( "context" "errors" + "io" "sync/atomic" "github.com/celestiaorg/rsmt2d" @@ -11,16 +12,16 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -var _ AccessorCloser = (*closeOnce)(nil) +var _ AccessorStreamer = (*closeOnce)(nil) var errAccessorClosed = errors.New("accessor is closed") type closeOnce struct { - f AccessorCloser + f AccessorStreamer closed atomic.Bool } -func WithClosedOnce(f AccessorCloser) AccessorCloser { +func WithClosedOnce(f AccessorStreamer) AccessorStreamer { return &closeOnce{f: f} } @@ -76,3 +77,10 @@ func (c *closeOnce) Shares(ctx context.Context) ([]share.Share, error) { } return c.f.Shares(ctx) } + +func (c *closeOnce) Reader() (io.Reader, error) { + if c.closed.Load() { + return nil, errAccessorClosed + } + return c.f.Reader() +} diff --git a/share/new_eds/close_once_test.go b/share/new_eds/close_once_test.go index 7ba9ada94b..a063423a1c 100644 --- a/share/new_eds/close_once_test.go +++ b/share/new_eds/close_once_test.go @@ -2,7 +2,9 @@ package eds import ( "context" + "io" "testing" + "testing/iotest" "github.com/stretchr/testify/require" @@ -68,6 +70,10 @@ func (s *stubEdsAccessorCloser) Shares(context.Context) ([]share.Share, error) { return nil, nil } +func (s *stubEdsAccessorCloser) Reader() (io.Reader, error) { + return iotest.ErrReader(nil), nil +} + func (s *stubEdsAccessorCloser) Close() error { s.closed = true return nil diff --git a/share/new_eds/nd_test.go b/share/new_eds/nd_test.go index a0780292ef..60fd9888c3 100644 --- a/share/new_eds/nd_test.go +++ b/share/new_eds/nd_test.go @@ -20,7 +20,7 @@ func TestNamespacedData(t *testing.T) { namespace := sharetest.RandV0Namespace() for amount := 1; amount < sharesAmount; amount++ { eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) - rsmt2d := Rsmt2D{ExtendedDataSquare: eds} + rsmt2d := &Rsmt2D{ExtendedDataSquare: eds} nd, err := NamespacedData(ctx, root, rsmt2d, namespace) require.NoError(t, err) require.True(t, len(nd) > 0) diff --git a/share/new_eds/proofs_cache.go b/share/new_eds/proofs_cache.go index 068b4c4d0b..faf3e580ef 100644 --- a/share/new_eds/proofs_cache.go +++ b/share/new_eds/proofs_cache.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "sync" "sync/atomic" @@ -20,13 +21,13 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -var _ Accessor = (*proofsCache)(nil) +var _ AccessorStreamer = (*proofsCache)(nil) // proofsCache is eds accessor that caches proofs for rows and columns. It also caches extended // axis Shares. It is used to speed up the process of building proofs for rows and columns, // reducing the number of reads from the underlying accessor. type proofsCache struct { - inner Accessor + inner AccessorStreamer // lock protects axisCache lock sync.RWMutex @@ -57,7 +58,7 @@ type axisWithProofs struct { // WithProofsCache creates a new eds accessor with caching of proofs for rows and columns. It is // used to speed up the process of building proofs for rows and columns, reducing the number of // reads from the underlying accessor. -func WithProofsCache(ac Accessor) Accessor { +func WithProofsCache(ac AccessorStreamer) AccessorStreamer { rows := make(map[int]axisWithProofs) cols := make(map[int]axisWithProofs) axisCache := []map[int]axisWithProofs{rows, cols} @@ -211,21 +212,28 @@ func (c *proofsCache) Shares(ctx context.Context) ([]share.Share, error) { return shares, nil } +func (c *proofsCache) Reader() (io.Reader, error) { + odsSize := c.Size(context.TODO()) / 2 + reader := NewSharesReader(odsSize, c.getShare) + return reader, nil +} + +func (c *proofsCache) Close() error { + return c.inner.Close() +} + func (c *proofsCache) axisShares(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) { ax, ok := c.getAxisFromCache(axisType, axisIdx) if ok && ax.shares != nil { return ax.shares, nil } - if len(ax.half.Shares) == 0 { - half, err := c.AxisHalf(ctx, axisType, axisIdx) - if err != nil { - return nil, err - } - ax.half = half + half, err := c.AxisHalf(ctx, axisType, axisIdx) + if err != nil { + return nil, err } - shares, err := ax.half.Extended() + shares, err := half.Extended() if err != nil { return nil, fmt.Errorf("extending shares: %w", err) } @@ -250,6 +258,30 @@ func (c *proofsCache) getAxisFromCache(axisType rsmt2d.Axis, axisIdx int) (axisW return ax, ok } +func (c *proofsCache) getShare(rowIdx, colIdx int) ([]byte, error) { + ctx := context.TODO() + odsSize := c.Size(ctx) / 2 + half, err := c.AxisHalf(ctx, rsmt2d.Row, rowIdx) + if err != nil { + return nil, fmt.Errorf("reading axis half: %w", err) + } + + // if share is from the same side of axis return share right away + if colIdx > odsSize == half.IsParity { + if half.IsParity { + colIdx -= odsSize + } + return half.Shares[colIdx], nil + } + + // if share index is from opposite part of axis, obtain full axis shares + shares, err := c.axisShares(ctx, rsmt2d.Row, rowIdx) + if err != nil { + return nil, fmt.Errorf("reading axis shares: %w", err) + } + return shares[colIdx], nil +} + // rowProofsGetter implements blockservice.BlockGetter interface type rowProofsGetter struct { proofs map[cid.Cid]blocks.Block diff --git a/share/new_eds/proofs_cache_test.go b/share/new_eds/proofs_cache_test.go index 8b22af6e4f..b570b15c1e 100644 --- a/share/new_eds/proofs_cache_test.go +++ b/share/new_eds/proofs_cache_test.go @@ -9,14 +9,19 @@ import ( ) func TestCache(t *testing.T) { - size := 8 - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ODSSize := 16 + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) t.Cleanup(cancel) - withProofsCache := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor { + newAccessor := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor { accessor := &Rsmt2D{ExtendedDataSquare: inner} return WithProofsCache(accessor) } + TestSuiteAccessor(ctx, t, newAccessor, ODSSize) - TestSuiteAccessor(ctx, t, withProofsCache, size) + newAccessorStreamer := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) AccessorStreamer { + accessor := &Rsmt2D{ExtendedDataSquare: inner} + return WithProofsCache(accessor) + } + TestStreamer(ctx, t, newAccessorStreamer, ODSSize) } diff --git a/share/new_eds/reader.go b/share/new_eds/reader.go new file mode 100644 index 0000000000..bef11527e2 --- /dev/null +++ b/share/new_eds/reader.go @@ -0,0 +1,97 @@ +package eds + +import ( + "bytes" + "errors" + "fmt" + "io" + + "github.com/celestiaorg/celestia-node/share" +) + +// BufferedReader will read Shares from getShare function into the buffer. +// It exposes the buffer to be read by io.Reader interface implementation +type BufferedReader struct { + buf *bytes.Buffer + getShare func(rowIdx, colIdx int) ([]byte, error) + // current is the amount of Shares stored in square that have been written by squareCopy. When + // current reaches total, squareCopy will prevent further reads by returning io.EOF + current, odsSize, total int +} + +func NewSharesReader(odsSize int, getShare func(rowIdx, colIdx int) ([]byte, error)) *BufferedReader { + return &BufferedReader{ + getShare: getShare, + buf: bytes.NewBuffer(nil), + odsSize: odsSize, + total: odsSize * odsSize, + } +} + +func (r *BufferedReader) Read(p []byte) (int, error) { + if r.current >= r.total && r.buf.Len() == 0 { + return 0, io.EOF + } + // if provided array is smaller than data in buf, read from buf + if len(p) <= r.buf.Len() { + return r.buf.Read(p) + } + n, err := io.ReadFull(r.buf, p) + if err == nil { + return n, nil + } + if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) { + return n, fmt.Errorf("unexpected error reading from buf: %w", err) + } + + written := n + for r.current < r.total { + rowIdx, colIdx := r.current/r.odsSize, r.current%r.odsSize + share, err := r.getShare(rowIdx, colIdx) + if err != nil { + return 0, fmt.Errorf("get share: %w", err) + } + + // copy share to provided buffer + emptySpace := len(p) - written + r.current++ + if len(share) < emptySpace { + n := copy(p[written:], share) + written += n + continue + } + + // if share didn't fit into buffer fully, store remaining bytes into inner buf + n := copy(p[written:], share[:emptySpace]) + written += n + n, err = r.buf.Write(share[emptySpace:]) + if err != nil { + return 0, fmt.Errorf("write share to inner buffer: %w", err) + } + if n != len(share)-emptySpace { + return 0, fmt.Errorf("share was not written fully: %w", io.ErrShortWrite) + } + return written, nil + } + return written, nil +} + +// ReadShares reads shares from the provided reader and constructs an Extended Data Square. Provided +// reader should contain shares in row-major order. +func ReadShares(r io.Reader, shareSize, odsSize int) ([]share.Share, error) { + shares := make([]share.Share, odsSize*odsSize) + var total int + for i := range shares { + share := make(share.Share, shareSize) + n, err := io.ReadFull(r, share) + if err != nil { + return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n) + } + if n != shareSize { + return nil, fmt.Errorf("share size mismatch: expected %v, got %v", shareSize, n) + } + shares[i] = share + total += n + } + return shares, nil +} diff --git a/share/new_eds/reader_test.go b/share/new_eds/reader_test.go new file mode 100644 index 0000000000..7b05b17b69 --- /dev/null +++ b/share/new_eds/reader_test.go @@ -0,0 +1,54 @@ +package eds + +import ( + "errors" + "io" + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" +) + +func TestSharesReader(t *testing.T) { + // create io.Writer that write random data + odsSize := 16 + eds := edstest.RandEDS(t, odsSize) + getShare := func(rowIdx, colIdx int) ([]byte, error) { + return eds.GetCell(uint(rowIdx), uint(colIdx)), nil + } + + reader := NewSharesReader(odsSize, getShare) + readBytes, err := readWithRandomBuffer(reader, 1024) + require.NoError(t, err) + expected := make([]byte, 0, odsSize*odsSize*share.Size) + for _, share := range eds.FlattenedODS() { + expected = append(expected, share...) + } + require.Len(t, readBytes, len(expected)) + require.Equal(t, expected, readBytes) +} + +// testRandReader reads from reader with buffers of random sizes. +func readWithRandomBuffer(reader io.Reader, maxBufSize int) ([]byte, error) { + // create buffer of random size + data := make([]byte, 0, maxBufSize) + for { + bufSize := rand.Intn(maxBufSize-1) + 1 + buf := make([]byte, bufSize) + n, err := reader.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + return nil, err + } + if n < bufSize { + buf = buf[:n] + } + data = append(data, buf...) + if errors.Is(err, io.EOF) { + break + } + } + return data, nil +} diff --git a/share/new_eds/rsmt2d.go b/share/new_eds/rsmt2d.go index e5ffe6704b..9bec30db66 100644 --- a/share/new_eds/rsmt2d.go +++ b/share/new_eds/rsmt2d.go @@ -3,6 +3,7 @@ package eds import ( "context" "fmt" + "io" "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/rsmt2d" @@ -11,7 +12,7 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -var _ Accessor = Rsmt2D{} +var _ AccessorStreamer = (*Rsmt2D)(nil) // Rsmt2D is a rsmt2d based in-memory implementation of Accessor. type Rsmt2D struct { @@ -19,12 +20,12 @@ type Rsmt2D struct { } // Size returns the size of the Extended Data Square. -func (eds Rsmt2D) Size(context.Context) int { +func (eds *Rsmt2D) Size(context.Context) int { return int(eds.Width()) } // Sample returns share and corresponding proof for row and column indices. -func (eds Rsmt2D) Sample( +func (eds *Rsmt2D) Sample( _ context.Context, rowIdx, colIdx int, ) (shwap.Sample, error) { @@ -33,7 +34,7 @@ func (eds Rsmt2D) Sample( // SampleForProofAxis samples a share from an Extended Data Square based on the provided // row and column indices and proof axis. It returns a sample with the share and proof. -func (eds Rsmt2D) SampleForProofAxis( +func (eds *Rsmt2D) SampleForProofAxis( rowIdx, colIdx int, proofType rsmt2d.Axis, ) (shwap.Sample, error) { @@ -61,7 +62,7 @@ func (eds Rsmt2D) SampleForProofAxis( } // AxisHalf returns Shares for the first half of the axis of the given type and index. -func (eds Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { +func (eds *Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { shares := getAxis(eds.ExtendedDataSquare, axisType, axisIdx) halfShares := shares[:eds.Width()/2] return AxisHalf{ @@ -72,13 +73,13 @@ func (eds Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) // HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and // side. -func (eds Rsmt2D) HalfRow(idx int, side shwap.RowSide) shwap.Row { +func (eds *Rsmt2D) HalfRow(idx int, side shwap.RowSide) shwap.Row { shares := eds.ExtendedDataSquare.Row(uint(idx)) return shwap.RowFromShares(shares, side) } // RowNamespaceData returns data for the given namespace and row index. -func (eds Rsmt2D) RowNamespaceData( +func (eds *Rsmt2D) RowNamespaceData( _ context.Context, namespace share.Namespace, rowIdx int, @@ -89,10 +90,35 @@ func (eds Rsmt2D) RowNamespaceData( // Shares returns data (ODS) shares extracted from the EDS. It returns new copy of the shares each // time. -func (eds Rsmt2D) Shares(_ context.Context) ([]share.Share, error) { +func (eds *Rsmt2D) Shares(_ context.Context) ([]share.Share, error) { return eds.ExtendedDataSquare.FlattenedODS(), nil } +func (eds *Rsmt2D) Close() error { + return nil +} + +// Reader returns binary reader for the file. +func (eds *Rsmt2D) Reader() (io.Reader, error) { + getShare := func(rowIdx, colIdx int) ([]byte, error) { + return eds.GetCell(uint(rowIdx), uint(colIdx)), nil + } + odsSize := int(eds.Width() / 2) + reader := NewSharesReader(odsSize, getShare) + return reader, nil +} + +// Rsmt2DFromShares constructs an Extended Data Square from shares. +func Rsmt2DFromShares(shares []share.Share, odsSize int) (Rsmt2D, error) { + treeFn := wrapper.NewConstructor(uint64(odsSize)) + eds, err := rsmt2d.ComputeExtendedDataSquare(shares, share.DefaultRSMT2DCodec(), treeFn) + if err != nil { + return Rsmt2D{}, fmt.Errorf("computing extended data square: %w", err) + } + + return Rsmt2D{eds}, nil +} + func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share { switch axisType { case rsmt2d.Row: diff --git a/share/new_eds/rsmt2d_test.go b/share/new_eds/rsmt2d_test.go index eafcb607ae..ef4fdae7bc 100644 --- a/share/new_eds/rsmt2d_test.go +++ b/share/new_eds/rsmt2d_test.go @@ -14,15 +14,20 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -func TestMemFile(t *testing.T) { - odsSize := 8 +func TestRsmt2dAccessor(t *testing.T) { + odsSize := 16 newAccessor := func(tb testing.TB, eds *rsmt2d.ExtendedDataSquare) Accessor { return &Rsmt2D{ExtendedDataSquare: eds} } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) t.Cleanup(cancel) TestSuiteAccessor(ctx, t, newAccessor, odsSize) + + newStreamer := func(tb testing.TB, eds *rsmt2d.ExtendedDataSquare) AccessorStreamer { + return &Rsmt2D{ExtendedDataSquare: eds} + } + TestStreamer(ctx, t, newStreamer, odsSize) } func TestRsmt2dHalfRow(t *testing.T) { diff --git a/share/new_eds/testing.go b/share/new_eds/testing.go index 3a1a3ad4a2..cbfa81f219 100644 --- a/share/new_eds/testing.go +++ b/share/new_eds/testing.go @@ -18,29 +18,49 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -type createAccessor func(testing.TB, *rsmt2d.ExtendedDataSquare) Accessor +type ( + createAccessor func(testing.TB, *rsmt2d.ExtendedDataSquare) Accessor + createAccessorStreamer func(testing.TB, *rsmt2d.ExtendedDataSquare) AccessorStreamer +) // TestSuiteAccessor runs a suite of tests for the given Accessor implementation. func TestSuiteAccessor( ctx context.Context, t *testing.T, createAccessor createAccessor, - odsSize int, + maxSize int, ) { - t.Run("Sample", func(t *testing.T) { - testAccessorSample(ctx, t, createAccessor, odsSize) - }) - - t.Run("AxisHalf", func(t *testing.T) { - testAccessorAxisHalf(ctx, t, createAccessor, odsSize) - }) - - t.Run("RowNamespaceData", func(t *testing.T) { - testAccessorRowNamespaceData(ctx, t, createAccessor, odsSize) - }) + minSize := 2 + if !checkPowerOfTwo(maxSize) { + t.Errorf("minSize must be power of 2: %v", maxSize) + } + for size := minSize; size <= maxSize; size *= 2 { + t.Run(fmt.Sprintf("Sample:%d", size), func(t *testing.T) { + testAccessorSample(ctx, t, createAccessor, size) + }) + + t.Run(fmt.Sprintf("AxisHalf:%d", size), func(t *testing.T) { + testAccessorAxisHalf(ctx, t, createAccessor, size) + }) + + t.Run(fmt.Sprintf("RowNamespaceData:%d", size), func(t *testing.T) { + testAccessorRowNamespaceData(ctx, t, createAccessor, size) + }) + + t.Run(fmt.Sprintf("Shares:%d", size), func(t *testing.T) { + testAccessorShares(ctx, t, createAccessor, size) + }) + } +} - t.Run("Shares", func(t *testing.T) { - testAccessorShares(ctx, t, createAccessor, odsSize) +func TestStreamer( + ctx context.Context, + t *testing.T, + create createAccessorStreamer, + odsSize int, +) { + t.Run("Reader", func(t *testing.T) { + testAccessorReader(ctx, t, create, odsSize) }) } @@ -232,6 +252,40 @@ func testAccessorShares( require.Equal(t, expected, shares) } +func testAccessorReader( + ctx context.Context, + t *testing.T, + create createAccessorStreamer, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + f := create(t, eds) + + // verify that the reader represented by file can be read from + // multiple times, without exhausting the underlying reader. + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + testReader(ctx, t, eds, f) + }() + } + wg.Wait() +} + +func testReader(ctx context.Context, t *testing.T, eds *rsmt2d.ExtendedDataSquare, as AccessorStreamer) { + reader, err := as.Reader() + require.NoError(t, err) + + odsSize := as.Size(ctx) / 2 + shares, err := ReadShares(reader, share.Size, odsSize) + require.NoError(t, err) + actual, err := Rsmt2DFromShares(shares, odsSize) + require.NoError(t, err) + require.True(t, eds.Equals(actual.ExtendedDataSquare)) +} + func BenchGetHalfAxisFromAccessor( ctx context.Context, b *testing.B, @@ -302,3 +356,11 @@ func (q quadrant) coordinates(edsSize int) (rowIdx, colIdx int) { rowIdx = edsSize/2*(int(q-1)/2) + 1 return rowIdx, colIdx } + +func checkPowerOfTwo(n int) bool { + // added one corner case if n is zero it will also consider as power 2 + if n == 0 { + return true + } + return n&(n-1) == 0 +} diff --git a/share/new_eds/validation_test.go b/share/new_eds/validation_test.go index ad4e6f2efc..98445c94ac 100644 --- a/share/new_eds/validation_test.go +++ b/share/new_eds/validation_test.go @@ -32,7 +32,7 @@ func TestValidation_Sample(t *testing.T) { t.Run(tt.name, func(t *testing.T) { randEDS := edstest.RandEDS(t, tt.odsSize) accessor := &Rsmt2D{ExtendedDataSquare: randEDS} - validation := WithValidation(WithCloser(accessor, nil)) + validation := WithValidation(AccessorAndStreamer(accessor, nil)) _, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx) if tt.expectFail { @@ -61,7 +61,7 @@ func TestValidation_AxisHalf(t *testing.T) { t.Run(tt.name, func(t *testing.T) { randEDS := edstest.RandEDS(t, tt.odsSize) accessor := &Rsmt2D{ExtendedDataSquare: randEDS} - validation := WithValidation(WithCloser(accessor, nil)) + validation := WithValidation(AccessorAndStreamer(accessor, nil)) _, err := validation.AxisHalf(context.Background(), tt.axisType, tt.axisIdx) if tt.expectFail { @@ -89,7 +89,7 @@ func TestValidation_RowNamespaceData(t *testing.T) { t.Run(tt.name, func(t *testing.T) { randEDS := edstest.RandEDS(t, tt.odsSize) accessor := &Rsmt2D{ExtendedDataSquare: randEDS} - validation := WithValidation(WithCloser(accessor, nil)) + validation := WithValidation(AccessorAndStreamer(accessor, nil)) ns := sharetest.RandV0Namespace() _, err := validation.RowNamespaceData(context.Background(), ns, tt.rowIdx) diff --git a/share/shwap/p2p/bitswap/block_fetch.go b/share/shwap/p2p/bitswap/block_fetch.go index 4a66eb72ef..7b4543fb26 100644 --- a/share/shwap/p2p/bitswap/block_fetch.go +++ b/share/shwap/p2p/bitswap/block_fetch.go @@ -52,8 +52,10 @@ func Fetch(ctx context.Context, exchg exchange.Interface, root *share.Root, blks } // maxPerFetch sets the limit for maximum items in a single fetch. -// This limit comes from server side default limit size on max possible simultaneous CID WANTs from a peer. -// https://github.com/ipfs/boxo/blob/dfd4a53ba828a368cec8d61c3fe12969ac6aa94c/bitswap/internal/defaults/defaults.go#L29-L30 +// This limit comes from server side default limit size on max possible simultaneous CID WANTs from +// a peer. +// +//https:github.com/ipfs/boxo/blob/dfd4a53ba828a368cec8d61c3fe12969ac6aa94c/bitswap/internal/defaults/defaults.go#L29-L30 const maxPerFetch = 1024 // fetch fetches given Blocks. diff --git a/share/shwap/p2p/bitswap/block_fetch_test.go b/share/shwap/p2p/bitswap/block_fetch_test.go index 1c8aa367e1..59e9384e65 100644 --- a/share/shwap/p2p/bitswap/block_fetch_test.go +++ b/share/shwap/p2p/bitswap/block_fetch_test.go @@ -117,7 +117,7 @@ func TestFetch_Duplicates(t *testing.T) { func newExchangeOverEDS(ctx context.Context, t *testing.T, rsmt2d *rsmt2d.ExtendedDataSquare) exchange.SessionExchange { bstore := &Blockstore{ Getter: testAccessorGetter{ - Accessor: eds.Rsmt2D{ExtendedDataSquare: rsmt2d}, + Accessor: &eds.Rsmt2D{ExtendedDataSquare: rsmt2d}, }, } return newExchange(ctx, t, bstore) diff --git a/share/shwap/p2p/bitswap/block_store.go b/share/shwap/p2p/bitswap/block_store.go index ffc6b0ede7..d42b40704e 100644 --- a/share/shwap/p2p/bitswap/block_store.go +++ b/share/shwap/p2p/bitswap/block_store.go @@ -10,7 +10,8 @@ import ( eds "github.com/celestiaorg/celestia-node/share/new_eds" ) -// AccessorGetter abstracts storage system that indexes and manages multiple eds.AccessorGetter by network height. +// AccessorGetter abstracts storage system that indexes and manages multiple eds.AccessorGetter by +// network height. type AccessorGetter interface { // GetByHeight returns an Accessor by its height. GetByHeight(ctx context.Context, height uint64) (eds.Accessor, error) diff --git a/share/shwap/row_namespace_data_test.go b/share/shwap/row_namespace_data_test.go index 07f434b87e..19f15ef7a6 100644 --- a/share/shwap/row_namespace_data_test.go +++ b/share/shwap/row_namespace_data_test.go @@ -65,7 +65,7 @@ func TestValidateNamespacedRow(t *testing.T) { namespace := sharetest.RandV0Namespace() for amount := 1; amount < sharesAmount; amount++ { randEDS, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) - rsmt2d := eds.Rsmt2D{ExtendedDataSquare: randEDS} + rsmt2d := &eds.Rsmt2D{ExtendedDataSquare: randEDS} nd, err := eds.NamespacedData(ctx, root, rsmt2d, namespace) require.NoError(t, err) require.True(t, len(nd) > 0) @@ -87,7 +87,7 @@ func TestNamespacedRowProtoEncoding(t *testing.T) { const odsSize = 8 namespace := sharetest.RandV0Namespace() randEDS, root := edstest.RandEDSWithNamespace(t, namespace, odsSize, odsSize) - rsmt2d := eds.Rsmt2D{ExtendedDataSquare: randEDS} + rsmt2d := &eds.Rsmt2D{ExtendedDataSquare: randEDS} nd, err := eds.NamespacedData(ctx, root, rsmt2d, namespace) require.NoError(t, err) require.True(t, len(nd) > 0) diff --git a/store/file/ods.go b/store/file/ods.go index 9c16549898..20e966e9d4 100644 --- a/store/file/ods.go +++ b/store/file/ods.go @@ -14,7 +14,7 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -var _ eds.AccessorCloser = (*ODSFile)(nil) +var _ eds.AccessorStreamer = (*ODSFile)(nil) type ODSFile struct { path string @@ -191,11 +191,27 @@ func (f *ODSFile) Shares(context.Context) ([]share.Share, error) { return ods.shares() } +// Reader returns binary reader for the file. It reads the shares from the ODS part of the square +// row by row. +func (f *ODSFile) Reader() (io.Reader, error) { + f.lock.RLock() + ods := f.ods + f.lock.RUnlock() + if ods != nil { + return ods.reader() + } + + offset := int64(f.hdr.Size()) + total := int64(f.hdr.shareSize) * int64(f.size()*f.size()/4) + reader := io.NewSectionReader(f.fl, offset, total) + return reader, nil +} + func (f *ODSFile) readAxisHalf(axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) { f.lock.RLock() - ODS := f.ods + ods := f.ods f.lock.RUnlock() - if ODS != nil { + if ods != nil { return f.ods.axisHalf(axisType, axisIdx) } @@ -217,10 +233,11 @@ func (f *ODSFile) readAxisHalf(axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, } func (f *ODSFile) readODS() (square, error) { - f.lock.Lock() - defer f.lock.Unlock() - if f.ods != nil { - return f.ods, nil + f.lock.RLock() + ods := f.ods + f.lock.RUnlock() + if ods != nil { + return ods, nil } // reset file pointer to the beginning of the file shares data @@ -235,7 +252,9 @@ func (f *ODSFile) readODS() (square, error) { } if !f.disableCache { + f.lock.Lock() f.ods = square + f.lock.Unlock() } return square, nil } diff --git a/store/file/ods_test.go b/store/file/ods_test.go index ee61c84281..307689092c 100644 --- a/store/file/ods_test.go +++ b/store/file/ods_test.go @@ -58,11 +58,13 @@ func TestReadODSFromFile(t *testing.T) { } func TestODSFile(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) t.Cleanup(cancel) - ODSSize := 8 - eds.TestSuiteAccessor(ctx, t, createODSFile, ODSSize) + ODSSize := 16 + eds.TestSuiteAccessor(ctx, t, createAccessor, ODSSize) + eds.TestStreamer(ctx, t, createCachedStreamer, ODSSize) + eds.TestStreamer(ctx, t, createStreamer, ODSSize) } // BenchmarkAxisFromODSFile/Size:32/ProofType:row/squareHalf:0-10 460231 2555 ns/op @@ -161,7 +163,22 @@ func BenchmarkSampleFromODSFileDisabledCache(b *testing.B) { eds.BenchGetSampleFromAccessor(ctx, b, newFile, minSize, maxSize) } -func createODSFile(t testing.TB, eds *rsmt2d.ExtendedDataSquare) eds.Accessor { +func createAccessor(t testing.TB, eds *rsmt2d.ExtendedDataSquare) eds.Accessor { + return createODSFile(t, eds) +} + +func createStreamer(t testing.TB, eds *rsmt2d.ExtendedDataSquare) eds.AccessorStreamer { + return createODSFile(t, eds) +} + +func createCachedStreamer(t testing.TB, eds *rsmt2d.ExtendedDataSquare) eds.AccessorStreamer { + f := createODSFile(t, eds) + _, err := f.readODS() + require.NoError(t, err) + return f +} + +func createODSFile(t testing.TB, eds *rsmt2d.ExtendedDataSquare) *ODSFile { path := t.TempDir() + "/" + strconv.Itoa(rand.Intn(1000)) fl, err := CreateODSFile(path, []byte{}, eds) require.NoError(t, err) diff --git a/store/file/q1q4_file.go b/store/file/q1q4_file.go index b23bd5f9b4..8b0bed86a9 100644 --- a/store/file/q1q4_file.go +++ b/store/file/q1q4_file.go @@ -12,7 +12,7 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -var _ eds.AccessorCloser = (*Q1Q4File)(nil) +var _ eds.AccessorStreamer = (*Q1Q4File)(nil) // Q1Q4File represents a file that contains the first and fourth quadrants of an extended data // square. It extends the ODSFile with the ability to read the fourth quadrant of the square. @@ -98,6 +98,10 @@ func (f *Q1Q4File) Shares(ctx context.Context) ([]share.Share, error) { return f.ods.Shares(ctx) } +func (f *Q1Q4File) Reader() (io.Reader, error) { + return f.ods.Reader() +} + func (f *Q1Q4File) Close() error { return f.ods.Close() } diff --git a/store/file/q1q4_file_test.go b/store/file/q1q4_file_test.go index daa0c4cb87..d739ad40ba 100644 --- a/store/file/q1q4_file_test.go +++ b/store/file/q1q4_file_test.go @@ -41,10 +41,10 @@ func TestCreateQ1Q4File(t *testing.T) { } func TestQ1Q4File(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) t.Cleanup(cancel) - ODSSize := 8 + ODSSize := 16 eds.TestSuiteAccessor(ctx, t, createQ1Q4File, ODSSize) } diff --git a/store/file/square.go b/store/file/square.go index 85a0e5aa9e..cd0b8ea5aa 100644 --- a/store/file/square.go +++ b/store/file/square.go @@ -1,7 +1,6 @@ package file import ( - "bufio" "fmt" "io" @@ -21,29 +20,26 @@ type square [][]share.Share func readSquare(r io.Reader, shareSize, edsSize int) (square, error) { odsLn := edsSize / 2 + shares, err := eds.ReadShares(r, shareSize, odsLn) + if err != nil { + return nil, fmt.Errorf("reading shares: %w", err) + } square := make(square, odsLn) for i := range square { - square[i] = make([]share.Share, odsLn) - for j := range square[i] { - square[i][j] = make(share.Share, shareSize) - } + square[i] = shares[i*odsLn : (i+1)*odsLn] } + return square, nil +} - br := bufio.NewReaderSize(r, 4096) - var total int - for i := 0; i < odsLn; i++ { - for j := 0; j < odsLn; j++ { - n, err := io.ReadFull(br, square[i][j]) - if err != nil { - return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n) - } - if n != shareSize { - return nil, fmt.Errorf("share size mismatch: expected %v, got %v", shareSize, n) - } - total += n - } +func (s square) reader() (io.Reader, error) { + if s == nil { + return nil, fmt.Errorf("ods file not cached") } - return square, nil + getShare := func(rowIdx, colIdx int) ([]byte, error) { + return s[rowIdx][colIdx], nil + } + reader := eds.NewSharesReader(s.size(), getShare) + return reader, nil } func (s square) size() int {