Skip to content

Commit

Permalink
rework streaming and caching
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss committed Jul 1, 2024
1 parent 3909fca commit ed532f8
Show file tree
Hide file tree
Showing 14 changed files with 273 additions and 229 deletions.
22 changes: 13 additions & 9 deletions share/new_eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,26 @@ type Accessor interface {
RowNamespaceData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error)
// Shares returns data (ODS) shares extracted from the Accessor.
Shares(ctx context.Context) ([]share.Share, error)
// Reader returns binary reader for the file (ODS) shares. It should read the shares from the
// ODS part of the square row by row.
Reader() (io.Reader, error)
}

// AccessorCloser is an interface that groups Accessor and io.Closer interfaces.
type AccessorCloser interface {
// AccessorStreamer is an interface that groups Accessor and Streamer interfaces.
type AccessorStreamer interface {
Accessor
Streamer
}

// Streamer provides sequential byte access to the ODS shares together with
// io.Closer for releasing the underlying resource when streaming is done.
type Streamer interface {
	// Reader returns binary reader for the file (ODS) shares. It should read the shares from the
	// ODS part of the square row by row.
	Reader() (io.Reader, error)
	io.Closer
}

type accessorCloser struct {
type accessorStreamer struct {
Accessor
io.Closer
Streamer
}

func WithCloser(a Accessor, c io.Closer) AccessorCloser {
return &accessorCloser{a, c}
func WithStreamer(a Accessor, s Streamer) AccessorStreamer {
return &accessorStreamer{a, s}
}
6 changes: 3 additions & 3 deletions share/new_eds/close_once.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ AccessorCloser = (*closeOnce)(nil)
var _ AccessorStreamer = (*closeOnce)(nil)

var errAccessorClosed = errors.New("accessor is closed")

type closeOnce struct {
f AccessorCloser
f AccessorStreamer
closed atomic.Bool
}

func WithClosedOnce(f AccessorCloser) AccessorCloser {
func WithClosedOnce(f AccessorStreamer) AccessorStreamer {
return &closeOnce{f: f}
}

Expand Down
52 changes: 42 additions & 10 deletions share/new_eds/proofs_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"

Expand All @@ -20,13 +21,13 @@ import (
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ Accessor = (*proofsCache)(nil)
var _ AccessorStreamer = (*proofsCache)(nil)

// proofsCache is eds accessor that caches proofs for rows and columns. It also caches extended
// axis Shares. It is used to speed up the process of building proofs for rows and columns,
// reducing the number of reads from the underlying accessor.
type proofsCache struct {
inner Accessor
inner AccessorStreamer

// lock protects axisCache
lock sync.RWMutex
Expand Down Expand Up @@ -57,7 +58,7 @@ type axisWithProofs struct {
// WithProofsCache creates a new eds accessor with caching of proofs for rows and columns. It is
// used to speed up the process of building proofs for rows and columns, reducing the number of
// reads from the underlying accessor.
func WithProofsCache(ac Accessor) Accessor {
func WithProofsCache(ac AccessorStreamer) AccessorStreamer {
rows := make(map[int]axisWithProofs)
cols := make(map[int]axisWithProofs)
axisCache := []map[int]axisWithProofs{rows, cols}
Expand Down Expand Up @@ -211,21 +212,28 @@ func (c *proofsCache) Shares(ctx context.Context) ([]share.Share, error) {
return shares, nil
}

// Reader returns a binary reader that streams the ODS part of the square row
// by row, resolving individual shares through the cache via getShare.
func (c *proofsCache) Reader() (io.Reader, error) {
	// ODS width is half of the extended square size.
	return NewSharesReader(c.Size(context.TODO())/2, c.getShare), nil
}

// Close releases the underlying accessor; the cache itself holds no resources.
func (c *proofsCache) Close() error {
	return c.inner.Close()
}

func (c *proofsCache) axisShares(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) {
ax, ok := c.getAxisFromCache(axisType, axisIdx)
if ok && ax.shares != nil {
return ax.shares, nil
}

if len(ax.half.Shares) == 0 {
half, err := c.AxisHalf(ctx, axisType, axisIdx)
if err != nil {
return nil, err
}
ax.half = half
half, err := c.AxisHalf(ctx, axisType, axisIdx)
if err != nil {
return nil, err
}

shares, err := ax.half.Extended()
shares, err := half.Extended()
if err != nil {
return nil, fmt.Errorf("extending shares: %w", err)
}
Expand All @@ -250,6 +258,30 @@ func (c *proofsCache) getAxisFromCache(axisType rsmt2d.Axis, axisIdx int) (axisW
return ax, ok
}

func (c *proofsCache) getShare(rowIdx, colIdx int) ([]byte, error) {
ctx := context.TODO()
odsSize := c.Size(ctx) / 2
half, err := c.AxisHalf(ctx, rsmt2d.Row, rowIdx)
if err != nil {
return nil, fmt.Errorf("reading axis half: %w", err)
}

// if share is from the same side of axis return share right away
if colIdx > odsSize == half.IsParity {
if half.IsParity {
colIdx = colIdx - odsSize

Check failure on line 272 in share/new_eds/proofs_cache.go

View workflow job for this annotation

GitHub Actions / go-ci / Lint

assignOp: replace `colIdx = colIdx - odsSize` with `colIdx -= odsSize` (gocritic)
}
return half.Shares[colIdx], nil
}

// if share index is from opposite part of axis, obtain full axis shares
shares, err := c.axisShares(ctx, rsmt2d.Row, rowIdx)
if err != nil {
return nil, fmt.Errorf("reading axis shares: %w", err)
}
return shares[colIdx], nil
}

// rowProofsGetter implements blockservice.BlockGetter interface
type rowProofsGetter struct {
proofs map[cid.Cid]blocks.Block
Expand Down
11 changes: 8 additions & 3 deletions share/new_eds/proofs_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@ import (
)

func TestCache(t *testing.T) {
size := 8
ODSSize := 8
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
t.Cleanup(cancel)

withProofsCache := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor {
newAccessor := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) Accessor {
accessor := &Rsmt2D{ExtendedDataSquare: inner}
return WithProofsCache(accessor)
}
TestSuiteAccessor(ctx, t, newAccessor, ODSSize)

TestSuiteAccessor(ctx, t, withProofsCache, size)
newAccessorStreamer := func(tb testing.TB, inner *rsmt2d.ExtendedDataSquare) AccessorStreamer {
accessor := &Rsmt2D{ExtendedDataSquare: inner}
return WithProofsCache(accessor)
}
TestStreamer(ctx, t, newAccessorStreamer, ODSSize)
}
85 changes: 70 additions & 15 deletions share/new_eds/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,95 @@ package eds
import (
"bytes"
"errors"
"fmt"
"io"

"github.com/celestiaorg/celestia-node/share"
)

// NewSharesReader creates a BufferedReader that streams odsSize*odsSize shares
// in row-major order, fetching each share lazily through getShare.
func NewSharesReader(odsSize int, getShare func(rowIdx, colIdx int) ([]byte, error)) *BufferedReader {
	return &BufferedReader{
		getShare: getShare,
		buf:      bytes.NewBuffer(nil),
		odsSize:  odsSize,
		total:    odsSize * odsSize,
	}
}

// BufferedReader exposes ODS shares as a byte stream via io.Reader. Bytes of a
// share that did not fit into the caller's buffer are parked in buf and served
// on the next Read call.
type BufferedReader struct {
	buf      *bytes.Buffer
	getShare func(rowIdx, colIdx int) ([]byte, error)
	// current is the amount of shares already emitted. When current reaches
	// total and buf is drained, Read reports io.EOF.
	current, odsSize, total int
}

// Read implements io.Reader. It first drains any leftover bytes from buf, then
// copies whole shares into p, spilling the tail of a partially-fitting share
// back into buf for the following call.
func (r *BufferedReader) Read(p []byte) (int, error) {
	if r.current >= r.total && r.buf.Len() == 0 {
		return 0, io.EOF
	}
	// If the provided slice is not larger than the buffered leftovers, serve
	// it straight from buf.
	if len(p) <= r.buf.Len() {
		return r.buf.Read(p)
	}
	n, err := io.ReadFull(r.buf, p)
	if err == nil {
		return n, nil
	}
	// EOF / ErrUnexpectedEOF simply mean buf had fewer bytes than len(p).
	if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
		return n, fmt.Errorf("unexpected error reading from buf: %w", err)
	}

	written := n
	for r.current < r.total {
		rowIdx, colIdx := r.current/r.odsSize, r.current%r.odsSize
		sh, err := r.getShare(rowIdx, colIdx)
		if err != nil {
			// Report bytes already copied into p so the caller does not lose
			// them: r.current has advanced past the shares they came from.
			return written, fmt.Errorf("get share: %w", err)
		}

		// Copy the share into the provided buffer.
		emptySpace := len(p) - written
		r.current++
		if len(sh) < emptySpace {
			written += copy(p[written:], sh)
			continue
		}

		// The share does not fit fully: copy what fits and park the remainder
		// in buf for the next Read.
		written += copy(p[written:], sh[:emptySpace])
		n, err := r.buf.Write(sh[emptySpace:])
		if err != nil {
			return written, fmt.Errorf("write share to inner buffer: %w", err)
		}
		if n != len(sh)-emptySpace {
			return written, fmt.Errorf("share was not written fully: %w", io.ErrShortWrite)
		}
		return written, nil
	}
	return written, nil
}

// minWriterTo writes to provided writer at least min amount of bytes
type minWriterTo interface {
WriteTo(writer io.Writer, minAmount int) (int, error)
// ReadShares reads odsSize*odsSize shares of shareSize bytes each from the
// provided reader. The reader is expected to contain shares in row-major
// order. The error message includes the total byte count consumed so far to
// aid debugging truncated streams.
func ReadShares(r io.Reader, shareSize, odsSize int) ([]share.Share, error) {
	shares := make([]share.Share, odsSize*odsSize)
	var total int
	for i := range shares {
		// `sh` (not `share`) avoids shadowing the share package.
		sh := make(share.Share, shareSize)
		n, err := io.ReadFull(r, sh)
		if err != nil {
			return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n)
		}
		// io.ReadFull guarantees n == shareSize on nil error, so no extra
		// length check is needed here.
		shares[i] = sh
		total += n
	}
	return shares, nil
}
64 changes: 21 additions & 43 deletions share/new_eds/reader_test.go
Original file line number Diff line number Diff line change
@@ -1,37 +1,42 @@
package eds

import (
"bytes"
crand "crypto/rand"
"errors"
"fmt"
"github.com/celestiaorg/celestia-node/share"

Check failure on line 6 in share/new_eds/reader_test.go

View workflow job for this annotation

GitHub Actions / go-ci / Lint

File is not `gofumpt`-ed with `-extra` (gofumpt)
"github.com/celestiaorg/celestia-node/share/eds/edstest"
"io"
"math/rand"
"testing"

Check failure on line 10 in share/new_eds/reader_test.go

View workflow job for this annotation

GitHub Actions / go-ci / Lint

File is not `gofumpt`-ed with `-extra` (gofumpt)

Check failure on line 11 in share/new_eds/reader_test.go

View workflow job for this annotation

GitHub Actions / go-ci / Lint

File is not `goimports`-ed with -local github.com/celestiaorg/celestia-node (goimports)
"github.com/stretchr/testify/require"
)

func TestNewBufferedReaderMany(t *testing.T) {
func TestSharesReaderMany(t *testing.T) {
// create io.Writer that write random data
for i := 0; i < 10000; i++ {
TestNewBufferedReader(t)
TestSharesReader(t)
}
}

func TestNewBufferedReader(t *testing.T) {
func TestSharesReader(t *testing.T) {
// create io.Writer that write random data
size := 200
randAmount := size + rand.Intn(size)
randBytes := make([]byte, randAmount)
_, err := crand.Read(randBytes)
require.NoError(t, err)

// randBytes := bytes.Repeat([]byte("1234567890"), 10)
odsSize := 16
eds := edstest.RandEDS(t, odsSize)
getShare := func(rowIdx, colIdx int) ([]byte, error) {
fmt.Println("get", rowIdx, colIdx)
return eds.GetCell(uint(rowIdx), uint(colIdx)), nil
}

reader := NewBufferedReader(randMinWriter{bytes.NewReader(randBytes)})
readBytes, err := readWithRandomBuffer(reader, size/10)
reader := NewSharesReader(odsSize, getShare)
readBytes, err := readWithRandomBuffer(reader, 1024)
require.NoError(t, err)
require.Equal(t, randBytes, readBytes)
expected := make([]byte, 0, odsSize*odsSize*share.Size)
for _, share := range eds.FlattenedODS() {
expected = append(expected, share...)
}
require.Len(t, readBytes, len(expected))
require.Equal(t, expected, readBytes)
}

// testRandReader reads from reader with buffers of random sizes.
Expand All @@ -50,36 +55,9 @@ func readWithRandomBuffer(reader io.Reader, maxBufSize int) ([]byte, error) {
}
data = append(data, buf...)
if errors.Is(err, io.EOF) {
fmt.Println("eof?")
break
}
}
return data, nil
}

// randMinWriter adapts a bytes.Reader to the minWriterTo interface, copying
// data to a writer in randomly-sized chunks (test helper).
type randMinWriter struct {
	*bytes.Reader
}

// WriteTo copies at least `limit` bytes from the underlying reader to writer,
// splitting the work into random chunk sizes of 1..limit-1 bytes.
// NOTE(review): it reports io.EOF when the writer accepts fewer bytes than
// requested — presumably to signal the source is exhausted; confirm callers
// rely on that.
func (lwt randMinWriter) WriteTo(writer io.Writer, limit int) (int, error) {
	var amount int
	for amount < limit {
		bufLn := limit
		if bufLn > 1 {
			// random chunk size in [1, limit-1]
			bufLn = rand.Intn(limit-1) + 1
		}
		buf := make([]byte, bufLn)
		n, err := lwt.Read(buf)
		if err != nil {
			return amount, err
		}
		n, err = writer.Write(buf[:n])
		amount += n
		if err != nil {
			return amount, err
		}
		// short write: the reader ran dry before filling the chunk
		if n < bufLn {
			return amount, io.EOF
		}
	}
	return amount, nil
}
Loading

0 comments on commit ed532f8

Please sign in to comment.