Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(shwap):Add eds streaming #3531

Merged
merged 10 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions share/new_eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,24 @@ type Accessor interface {
Shares(ctx context.Context) ([]share.Share, error)
}

// AccessorCloser is an interface that groups Accessor and io.Closer interfaces.
type AccessorCloser interface {
// AccessorStreamer is an interface that groups Accessor and Streamer interfaces.
type AccessorStreamer interface {
Accessor
Streamer
}

type Streamer interface {
// Reader returns binary reader for the file (ODS) shares. It should read the shares from the
// ODS part of the square row by row.
Reader() (io.Reader, error)
io.Closer
}

type accessorCloser struct {
type accessorStreamer struct {
Accessor
io.Closer
Streamer
}

func WithCloser(a Accessor, c io.Closer) AccessorCloser {
return &accessorCloser{a, c}
func WithStreamer(a Accessor, s Streamer) AccessorStreamer {
return &accessorStreamer{a, s}
}
14 changes: 11 additions & 3 deletions share/new_eds/close_once.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package eds
import (
"context"
"errors"
"io"
"sync/atomic"

"github.com/celestiaorg/rsmt2d"
Expand All @@ -11,16 +12,16 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ AccessorCloser = (*closeOnce)(nil)
var _ AccessorStreamer = (*closeOnce)(nil)

var errAccessorClosed = errors.New("accessor is closed")

type closeOnce struct {
f AccessorCloser
f AccessorStreamer
closed atomic.Bool
}

func WithClosedOnce(f AccessorCloser) AccessorCloser {
func WithClosedOnce(f AccessorStreamer) AccessorStreamer {
return &closeOnce{f: f}
}

Expand Down Expand Up @@ -76,3 +77,10 @@ func (c *closeOnce) Shares(ctx context.Context) ([]share.Share, error) {
}
return c.f.Shares(ctx)
}

func (c *closeOnce) Reader() (io.Reader, error) {
if c.closed.Load() {
return nil, errAccessorClosed
}
return c.f.Reader()
}
6 changes: 6 additions & 0 deletions share/new_eds/close_once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package eds

import (
"context"
"io"
"testing"
"testing/iotest"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -68,6 +70,10 @@ func (s *stubEdsAccessorCloser) Shares(context.Context) ([]share.Share, error) {
return nil, nil
}

func (s *stubEdsAccessorCloser) Reader() (io.Reader, error) {
return iotest.ErrReader(nil), nil
}

func (s *stubEdsAccessorCloser) Close() error {
s.closed = true
return nil
Expand Down
2 changes: 1 addition & 1 deletion share/new_eds/nd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestNamespacedData(t *testing.T) {
namespace := sharetest.RandV0Namespace()
for amount := 1; amount < sharesAmount; amount++ {
eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize)
rsmt2d := Rsmt2D{ExtendedDataSquare: eds}
rsmt2d := &Rsmt2D{ExtendedDataSquare: eds}
nd, err := NamespacedData(ctx, root, rsmt2d, namespace)
require.NoError(t, err)
require.True(t, len(nd) > 0)
Expand Down
52 changes: 42 additions & 10 deletions share/new_eds/proofs_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"context"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"

Expand All @@ -20,13 +21,13 @@
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ Accessor = (*proofsCache)(nil)
var _ AccessorStreamer = (*proofsCache)(nil)

// proofsCache is eds accessor that caches proofs for rows and columns. It also caches extended
// axis Shares. It is used to speed up the process of building proofs for rows and columns,
// reducing the number of reads from the underlying accessor.
type proofsCache struct {
inner Accessor
inner AccessorStreamer

// lock protects axisCache
lock sync.RWMutex
Expand Down Expand Up @@ -57,7 +58,7 @@
// WithProofsCache creates a new eds accessor with caching of proofs for rows and columns. It is
// used to speed up the process of building proofs for rows and columns, reducing the number of
// reads from the underlying accessor.
func WithProofsCache(ac Accessor) Accessor {
func WithProofsCache(ac AccessorStreamer) AccessorStreamer {
rows := make(map[int]axisWithProofs)
cols := make(map[int]axisWithProofs)
axisCache := []map[int]axisWithProofs{rows, cols}
Expand Down Expand Up @@ -211,21 +212,28 @@
return shares, nil
}

func (c *proofsCache) Reader() (io.Reader, error) {
odsSize := c.Size(context.TODO()) / 2
reader := NewSharesReader(odsSize, c.getShare)
return reader, nil
}

func (c *proofsCache) Close() error {
return c.inner.Close()
}

func (c *proofsCache) axisShares(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) {
ax, ok := c.getAxisFromCache(axisType, axisIdx)
if ok && ax.shares != nil {
return ax.shares, nil
}

if len(ax.half.Shares) == 0 {
half, err := c.AxisHalf(ctx, axisType, axisIdx)
if err != nil {
return nil, err
}
ax.half = half
half, err := c.AxisHalf(ctx, axisType, axisIdx)
if err != nil {
return nil, err
}

shares, err := ax.half.Extended()
Comment on lines -220 to -228
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for this change? I don't follow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It caches half axis independently inside AxisHalf call

shares, err := half.Extended()
if err != nil {
return nil, fmt.Errorf("extending shares: %w", err)
}
Expand All @@ -250,6 +258,30 @@
return ax, ok
}

func (c *proofsCache) getShare(rowIdx, colIdx int) ([]byte, error) {
ctx := context.TODO()
odsSize := c.Size(ctx) / 2
half, err := c.AxisHalf(ctx, rsmt2d.Row, rowIdx)
if err != nil {
return nil, fmt.Errorf("reading axis half: %w", err)
}

// if share is from the same side of axis return share right away
if colIdx > odsSize == half.IsParity {
if half.IsParity {
colIdx = colIdx - odsSize

Check failure on line 272 in share/new_eds/proofs_cache.go

View workflow job for this annotation

GitHub Actions / go-ci / Lint

assignOp: replace `colIdx = colIdx - odsSize` with `colIdx -= odsSize` (gocritic)
}
return half.Shares[colIdx], nil
}

// if share index is from opposite part of axis, obtain full axis shares
shares, err := c.axisShares(ctx, rsmt2d.Row, rowIdx)
if err != nil {
return nil, fmt.Errorf("reading axis shares: %w", err)
}
return shares[colIdx], nil
}

// rowProofsGetter implements blockservice.BlockGetter interface
type rowProofsGetter struct {
proofs map[cid.Cid]blocks.Block
Expand Down
11 changes: 8 additions & 3 deletions share/new_eds/proofs_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ import (
)

func TestCache(t *testing.T) {
size := 8
ODSSize := 8
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
t.Cleanup(cancel)

withProofsCache := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor {
newAccessor := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor {
accessor := &Rsmt2D{ExtendedDataSquare: inner}
return WithProofsCache(accessor)
}
TestSuiteAccessor(ctx, t, newAccessor, ODSSize)

TestSuiteAccessor(ctx, t, withProofsCache, size)
newAccessorStreamer := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) AccessorStreamer {
accessor := &Rsmt2D{ExtendedDataSquare: inner}
return WithProofsCache(accessor)
}
TestStreamer(ctx, t, newAccessorStreamer, ODSSize)
}
97 changes: 97 additions & 0 deletions share/new_eds/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package eds

import (
"bytes"
"errors"
"fmt"
"io"

"github.com/celestiaorg/celestia-node/share"
)

func NewSharesReader(odsSize int, getShare func(rowIdx, colIdx int) ([]byte, error)) *BufferedReader {
return &BufferedReader{
getShare: getShare,
buf: bytes.NewBuffer(nil),
odsSize: odsSize,
total: odsSize * odsSize,
}
}

// BufferedReader will read Shares from inMemOds into the buffer.
// It exposes the buffer to be read by io.Reader interface implementation
type BufferedReader struct {
buf *bytes.Buffer
getShare func(rowIdx, colIdx int) ([]byte, error)
// current is the amount of Shares stored in square that have been written by squareCopy. When
// current reaches total, squareCopy will prevent further reads by returning io.EOF
current, odsSize, total int
}

func (r *BufferedReader) Read(p []byte) (int, error) {
if r.current >= r.total && r.buf.Len() == 0 {
return 0, io.EOF
}
// if provided array is smaller than data in buf, read from buf
if len(p) <= r.buf.Len() {
return r.buf.Read(p)
}
n, err := io.ReadFull(r.buf, p)
if err == nil {
return n, nil
}
if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
return n, fmt.Errorf("unexpected error reading from buf: %w", err)
}

written := n
for r.current < r.total {
rowIdx, colIdx := r.current/r.odsSize, r.current%r.odsSize
share, err := r.getShare(rowIdx, colIdx)
if err != nil {
return 0, fmt.Errorf("get share; %w", err)
}

// copy share to provided buffer
emptySpace := len(p) - written
r.current++
if len(share) < emptySpace {
n := copy(p[written:], share)
written += n
continue
}

// if share didn't fit into buffer fully, store remaining bytes into inner buf
n := copy(p[written:], share[:emptySpace])
written += n
n, err = r.buf.Write(share[emptySpace:])
if err != nil {
return 0, fmt.Errorf("write share to inner buffer: %w", err)
}
if n != len(share)-emptySpace {
return 0, fmt.Errorf("share was not written fully: %w", io.ErrShortWrite)
}
return written, nil
}
return written, nil
}

// ReadShares reads shares from the provided reader and constructs an Extended Data Square. Provided
// reader should contain shares in row-major order.
func ReadShares(r io.Reader, shareSize, odsSize int) ([]share.Share, error) {
shares := make([]share.Share, odsSize*odsSize)
var total int
for i := range shares {
share := make(share.Share, shareSize)
n, err := io.ReadFull(r, share)
if err != nil {
return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n)
}
if n != shareSize {
return nil, fmt.Errorf("share size mismatch: expected %v, got %v", shareSize, n)
}
shares[i] = share
total += n
}
return shares, nil
}
64 changes: 64 additions & 0 deletions share/new_eds/reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package eds

import (
"errors"
"fmt"
"io"
"math/rand"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
)

func TestSharesReaderMany(t *testing.T) {
// create io.Writer that write random data
for i := 0; i < 10000; i++ {
TestSharesReader(t)
}
}

func TestSharesReader(t *testing.T) {
// create io.Writer that write random data
odsSize := 16
eds := edstest.RandEDS(t, odsSize)
getShare := func(rowIdx, colIdx int) ([]byte, error) {
fmt.Println("get", rowIdx, colIdx)
return eds.GetCell(uint(rowIdx), uint(colIdx)), nil
}

reader := NewSharesReader(odsSize, getShare)
readBytes, err := readWithRandomBuffer(reader, 1024)
require.NoError(t, err)
expected := make([]byte, 0, odsSize*odsSize*share.Size)
for _, share := range eds.FlattenedODS() {
expected = append(expected, share...)
}
require.Len(t, readBytes, len(expected))
require.Equal(t, expected, readBytes)
}

// testRandReader reads from reader with buffers of random sizes.
func readWithRandomBuffer(reader io.Reader, maxBufSize int) ([]byte, error) {
// create buffer of random size
data := make([]byte, 0, maxBufSize)
for {
bufSize := rand.Intn(maxBufSize-1) + 1
buf := make([]byte, bufSize)
n, err := reader.Read(buf)
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
if n < bufSize {
buf = buf[:n]
}
data = append(data, buf...)
if errors.Is(err, io.EOF) {
fmt.Println("eof?")
break
}
}
return data, nil
}
Loading
Loading