Skip to content

Commit

Permalink
feat(shwap): add eds streaming (celestiaorg#3531)
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss committed Jul 6, 2024
1 parent 807e335 commit 660246d
Show file tree
Hide file tree
Showing 21 changed files with 433 additions and 92 deletions.
19 changes: 13 additions & 6 deletions share/new_eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,24 @@ type Accessor interface {
Shares(ctx context.Context) ([]share.Share, error)
}

// AccessorCloser is an interface that groups Accessor and io.Closer interfaces.
type AccessorCloser interface {
// AccessorStreamer is an interface that groups Accessor and Streamer interfaces,
// allowing a single value to serve both random access and streaming reads.
type AccessorStreamer interface {
	Accessor
	Streamer
}

// Streamer is a type that can expose its shares as a binary stream and
// release underlying resources via io.Closer when reading is done.
type Streamer interface {
	// Reader returns binary reader for the shares. It should read the shares from the
	// ODS part of the square row by row.
	Reader() (io.Reader, error)
	io.Closer
}

type accessorCloser struct {
type accessorStreamer struct {
Accessor
io.Closer
Streamer
}

func WithCloser(a Accessor, c io.Closer) AccessorCloser {
return &accessorCloser{a, c}
// AccessorAndStreamer joins an Accessor and a Streamer into a single
// AccessorStreamer value.
func AccessorAndStreamer(a Accessor, s Streamer) AccessorStreamer {
	return &accessorStreamer{
		Accessor: a,
		Streamer: s,
	}
}
14 changes: 11 additions & 3 deletions share/new_eds/close_once.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package eds
import (
"context"
"errors"
"io"
"sync/atomic"

"github.com/celestiaorg/rsmt2d"
Expand All @@ -11,16 +12,16 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ AccessorCloser = (*closeOnce)(nil)
var _ AccessorStreamer = (*closeOnce)(nil)

var errAccessorClosed = errors.New("accessor is closed")

// closeOnce wraps an AccessorStreamer and guarantees that Close reaches the
// underlying accessor at most once; every operation after closing fails with
// errAccessorClosed.
type closeOnce struct {
	f AccessorStreamer
	// closed flips to true on the first Close call and is never reset.
	closed atomic.Bool
}

// WithClosedOnce wraps f so that Close is idempotent: only the first call is
// forwarded to the underlying AccessorStreamer, and any access after closing
// returns errAccessorClosed.
func WithClosedOnce(f AccessorStreamer) AccessorStreamer {
	return &closeOnce{f: f}
}

Expand Down Expand Up @@ -76,3 +77,10 @@ func (c *closeOnce) Shares(ctx context.Context) ([]share.Share, error) {
}
return c.f.Shares(ctx)
}

// Reader returns the underlying accessor's binary share reader, or
// errAccessorClosed if the accessor has already been closed.
func (c *closeOnce) Reader() (io.Reader, error) {
	if c.closed.Load() {
		return nil, errAccessorClosed
	}
	return c.f.Reader()
}
6 changes: 6 additions & 0 deletions share/new_eds/close_once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package eds

import (
"context"
"io"
"testing"
"testing/iotest"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -68,6 +70,10 @@ func (s *stubEdsAccessorCloser) Shares(context.Context) ([]share.Share, error) {
return nil, nil
}

// Reader satisfies the Streamer interface for the stub; iotest.ErrReader(nil)
// yields a reader that never produces data, which is enough for the
// close-once tests that only check gating, not content.
func (s *stubEdsAccessorCloser) Reader() (io.Reader, error) {
	return iotest.ErrReader(nil), nil
}

// Close records that the stub was closed so tests can assert on it.
func (s *stubEdsAccessorCloser) Close() error {
	s.closed = true
	return nil
}
Expand Down
2 changes: 1 addition & 1 deletion share/new_eds/nd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestNamespacedData(t *testing.T) {
namespace := sharetest.RandV0Namespace()
for amount := 1; amount < sharesAmount; amount++ {
eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize)
rsmt2d := Rsmt2D{ExtendedDataSquare: eds}
rsmt2d := &Rsmt2D{ExtendedDataSquare: eds}
nd, err := NamespacedData(ctx, root, rsmt2d, namespace)
require.NoError(t, err)
require.True(t, len(nd) > 0)
Expand Down
52 changes: 42 additions & 10 deletions share/new_eds/proofs_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"

Expand All @@ -20,13 +21,13 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ Accessor = (*proofsCache)(nil)
var _ AccessorStreamer = (*proofsCache)(nil)

// proofsCache is eds accessor that caches proofs for rows and columns. It also caches extended
// axis Shares. It is used to speed up the process of building proofs for rows and columns,
// reducing the number of reads from the underlying accessor.
type proofsCache struct {
inner Accessor
inner AccessorStreamer

// lock protects axisCache
lock sync.RWMutex
Expand Down Expand Up @@ -57,7 +58,7 @@ type axisWithProofs struct {
// WithProofsCache creates a new eds accessor with caching of proofs for rows and columns. It is
// used to speed up the process of building proofs for rows and columns, reducing the number of
// reads from the underlying accessor.
func WithProofsCache(ac Accessor) Accessor {
func WithProofsCache(ac AccessorStreamer) AccessorStreamer {
rows := make(map[int]axisWithProofs)
cols := make(map[int]axisWithProofs)
axisCache := []map[int]axisWithProofs{rows, cols}
Expand Down Expand Up @@ -211,21 +212,28 @@ func (c *proofsCache) Shares(ctx context.Context) ([]share.Share, error) {
return shares, nil
}

// Reader returns a binary reader over the ODS shares, served row by row
// through the proofs cache via getShare.
// NOTE(review): context.TODO() is used because the Streamer interface carries
// no ctx parameter — confirm cancellation is genuinely not needed here.
func (c *proofsCache) Reader() (io.Reader, error) {
	odsSize := c.Size(context.TODO()) / 2
	reader := NewSharesReader(odsSize, c.getShare)
	return reader, nil
}

// Close closes the underlying accessor.
func (c *proofsCache) Close() error {
	return c.inner.Close()
}

func (c *proofsCache) axisShares(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) {
ax, ok := c.getAxisFromCache(axisType, axisIdx)
if ok && ax.shares != nil {
return ax.shares, nil
}

if len(ax.half.Shares) == 0 {
half, err := c.AxisHalf(ctx, axisType, axisIdx)
if err != nil {
return nil, err
}
ax.half = half
half, err := c.AxisHalf(ctx, axisType, axisIdx)
if err != nil {
return nil, err
}

shares, err := ax.half.Extended()
shares, err := half.Extended()
if err != nil {
return nil, fmt.Errorf("extending shares: %w", err)
}
Expand All @@ -250,6 +258,30 @@ func (c *proofsCache) getAxisFromCache(axisType rsmt2d.Axis, axisIdx int) (axisW
return ax, ok
}

// getShare returns the share at (rowIdx, colIdx) of the extended square,
// serving it directly from the cached row half when the index falls on the
// same side of the axis, and reconstructing the full extended row otherwise.
func (c *proofsCache) getShare(rowIdx, colIdx int) ([]byte, error) {
	ctx := context.TODO()
	odsSize := c.Size(ctx) / 2
	half, err := c.AxisHalf(ctx, rsmt2d.Row, rowIdx)
	if err != nil {
		return nil, fmt.Errorf("reading axis half: %w", err)
	}

	// If the share lies on the same side of the axis as the cached half,
	// return it right away. The parity side starts at colIdx == odsSize, so
	// the comparison must be >=; using > would misclassify colIdx == odsSize
	// and index half.Shares (length odsSize) out of range for a non-parity half.
	if colIdx >= odsSize == half.IsParity {
		if half.IsParity {
			colIdx -= odsSize
		}
		return half.Shares[colIdx], nil
	}

	// The share sits on the opposite side of the axis; obtain the full
	// extended row and index into it.
	shares, err := c.axisShares(ctx, rsmt2d.Row, rowIdx)
	if err != nil {
		return nil, fmt.Errorf("reading axis shares: %w", err)
	}
	return shares[colIdx], nil
}

// rowProofsGetter implements blockservice.BlockGetter interface
type rowProofsGetter struct {
proofs map[cid.Cid]blocks.Block
Expand Down
13 changes: 9 additions & 4 deletions share/new_eds/proofs_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ import (
)

func TestCache(t *testing.T) {
size := 8
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ODSSize := 16
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
t.Cleanup(cancel)

withProofsCache := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor {
newAccessor := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor {
accessor := &Rsmt2D{ExtendedDataSquare: inner}
return WithProofsCache(accessor)
}
TestSuiteAccessor(ctx, t, newAccessor, ODSSize)

TestSuiteAccessor(ctx, t, withProofsCache, size)
newAccessorStreamer := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) AccessorStreamer {
accessor := &Rsmt2D{ExtendedDataSquare: inner}
return WithProofsCache(accessor)
}
TestStreamer(ctx, t, newAccessorStreamer, ODSSize)
}
97 changes: 97 additions & 0 deletions share/new_eds/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package eds

import (
"bytes"
"errors"
"fmt"
"io"

"github.com/celestiaorg/celestia-node/share"
)

// BufferedReader will read Shares from getShare function into the buffer.
// It exposes the buffer to be read by io.Reader interface implementation
type BufferedReader struct {
buf *bytes.Buffer
getShare func(rowIdx, colIdx int) ([]byte, error)
// current is the amount of Shares stored in square that have been written by squareCopy. When
// current reaches total, squareCopy will prevent further reads by returning io.EOF
current, odsSize, total int
}

func NewSharesReader(odsSize int, getShare func(rowIdx, colIdx int) ([]byte, error)) *BufferedReader {
return &BufferedReader{
getShare: getShare,
buf: bytes.NewBuffer(nil),
odsSize: odsSize,
total: odsSize * odsSize,
}
}

func (r *BufferedReader) Read(p []byte) (int, error) {
if r.current >= r.total && r.buf.Len() == 0 {
return 0, io.EOF
}
// if provided array is smaller than data in buf, read from buf
if len(p) <= r.buf.Len() {
return r.buf.Read(p)
}
n, err := io.ReadFull(r.buf, p)
if err == nil {
return n, nil
}
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
return n, fmt.Errorf("unexpected error reading from buf: %w", err)
}

written := n
for r.current < r.total {
rowIdx, colIdx := r.current/r.odsSize, r.current%r.odsSize
share, err := r.getShare(rowIdx, colIdx)
if err != nil {
return 0, fmt.Errorf("get share: %w", err)
}

// copy share to provided buffer
emptySpace := len(p) - written
r.current++
if len(share) < emptySpace {
n := copy(p[written:], share)
written += n
continue
}

// if share didn't fit into buffer fully, store remaining bytes into inner buf
n := copy(p[written:], share[:emptySpace])
written += n
n, err = r.buf.Write(share[emptySpace:])
if err != nil {
return 0, fmt.Errorf("write share to inner buffer: %w", err)
}
if n != len(share)-emptySpace {
return 0, fmt.Errorf("share was not written fully: %w", io.ErrShortWrite)
}
return written, nil
}
return written, nil
}

// ReadShares reads odsSize*odsSize shares of shareSize bytes each from the
// provided reader and returns them as a flat slice. The reader is expected to
// contain the shares in row-major order.
func ReadShares(r io.Reader, shareSize, odsSize int) ([]share.Share, error) {
	shares := make([]share.Share, odsSize*odsSize)
	var total int
	for i := range shares {
		// named sh to avoid shadowing the imported share package
		sh := make(share.Share, shareSize)
		n, err := io.ReadFull(r, sh)
		if err != nil {
			return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n)
		}
		// io.ReadFull guarantees n == shareSize when err is nil, so no extra
		// length check is needed.
		shares[i] = sh
		total += n
	}
	return shares, nil
}
54 changes: 54 additions & 0 deletions share/new_eds/reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package eds

import (
"errors"
"io"
"math/rand"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
)

// TestSharesReader verifies that NewSharesReader streams exactly the
// flattened ODS bytes of a random EDS when read with random buffer sizes.
func TestSharesReader(t *testing.T) {
	odsSize := 16
	eds := edstest.RandEDS(t, odsSize)
	getShare := func(rowIdx, colIdx int) ([]byte, error) {
		return eds.GetCell(uint(rowIdx), uint(colIdx)), nil
	}

	readBytes, err := readWithRandomBuffer(NewSharesReader(odsSize, getShare), 1024)
	require.NoError(t, err)

	expected := make([]byte, 0, odsSize*odsSize*share.Size)
	for _, sh := range eds.FlattenedODS() {
		expected = append(expected, sh...)
	}
	require.Len(t, readBytes, len(expected))
	require.Equal(t, expected, readBytes)
}

// testRandReader reads from reader with buffers of random sizes.
func readWithRandomBuffer(reader io.Reader, maxBufSize int) ([]byte, error) {
// create buffer of random size
data := make([]byte, 0, maxBufSize)
for {
bufSize := rand.Intn(maxBufSize-1) + 1
buf := make([]byte, bufSize)
n, err := reader.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
if n < bufSize {
buf = buf[:n]
}
data = append(data, buf...)
if errors.Is(err, io.EOF) {
break
}
}
return data, nil
}
Loading

0 comments on commit 660246d

Please sign in to comment.