Skip to content

Commit

Permalink
refactor(share): GetShare -> GetSamples
Browse files Browse the repository at this point in the history
  • Loading branch information
Wondertan authored and cristaloleg committed Nov 8, 2024
1 parent 353141f commit 9024839
Show file tree
Hide file tree
Showing 22 changed files with 269 additions and 197 deletions.
7 changes: 3 additions & 4 deletions blob/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -902,10 +902,9 @@ func createService(ctx context.Context, t testing.TB, shares []libshare.Share) *
nd, err := eds.NamespaceData(ctx, accessor, ns)
return nd, err
})
shareGetter.EXPECT().GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, row, col int) (libshare.Share, error) {
s, err := accessor.Sample(ctx, row, col)
return s.Share, err
shareGetter.EXPECT().GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().
DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
return smpls, nil
})

// create header and put it into the store
Expand Down
30 changes: 15 additions & 15 deletions nodebuilder/share/mocks/api.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 13 additions & 1 deletion nodebuilder/share/share.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,24 @@ type module struct {
hs headerServ.Module
}

// TODO(@Wondertan): break
func (m module) GetShare(ctx context.Context, height uint64, row, col int) (libshare.Share, error) {
header, err := m.hs.GetByHeight(ctx, height)
if err != nil {
return libshare.Share{}, err
}
return m.getter.GetShare(ctx, header, row, col)

idx, err := shwap.SampleIndexFromCoordinates(row, col, len(header.DAH.RowRoots))
if err != nil {
return libshare.Share{}, err
}

smpls, err := m.getter.GetSamples(ctx, header, []shwap.SampleIndex{idx})
if err != nil {
return libshare.Share{}, err
}

return smpls[0].Share, nil
}

func (m module) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) {
Expand Down
70 changes: 24 additions & 46 deletions share/availability/light/availability.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"sync"
"time"

"github.com/ipfs/boxo/blockstore"
"github.com/ipfs/go-datastore"
Expand Down Expand Up @@ -114,59 +113,34 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header
return nil
}

var (
mutex sync.Mutex
failedSamples []Sample
wg sync.WaitGroup
)

log.Debugw("starting sampling session", "height", header.Height())

// remove one second from the deadline to ensure we have enough time to process the results
samplingCtx, cancel := context.WithCancel(ctx)
if deadline, ok := ctx.Deadline(); ok {
samplingCtx, cancel = context.WithDeadline(ctx, deadline.Add(-time.Second))
}
defer cancel()

// Concurrently sample shares
for _, s := range samples.Remaining {
wg.Add(1)
go func(s Sample) {
defer wg.Done()
_, err := la.getter.GetShare(samplingCtx, header, s.Row, s.Col)
mutex.Lock()
defer mutex.Unlock()
if err != nil {
log.Debugw("error fetching share", "height", header.Height(), "row", s.Row, "col", s.Col)
failedSamples = append(failedSamples, s)
} else {
samples.Available = append(samples.Available, s)
}
}(s)
}
wg.Wait()
log.Debugw("starting sampling session", "root", dah.String())

// Update remaining samples with failed ones
samples.Remaining = failedSamples
idxs := make([]shwap.SampleIndex, len(samples.Available))
for i, s := range samples.Available {
idx, err := shwap.SampleIndexFromCoordinates(int(s.Row), int(s.Col), len(dah.RowRoots))
if err != nil {
return err
}

// Store the updated sampling result
updatedData, err := json.Marshal(samples)
if err != nil {
return err
}
la.dsLk.Lock()
err = la.ds.Put(ctx, key, updatedData)
la.dsLk.Unlock()
if err != nil {
return fmt.Errorf("store sampling result: %w", err)
idxs[i] = idx
}

smpls, err := la.getter.GetSamples(ctx, header, idxs)
if errors.Is(ctx.Err(), context.Canceled) {
// Availability did not complete due to context cancellation, return context error instead of
// share.ErrNotAvailable
return ctx.Err()
}
if len(smpls) == 0 {
return share.ErrNotAvailable
}

var failedSamples []Sample
for i, smpl := range smpls {
if smpl.IsEmpty() {
failedSamples = append(failedSamples, samples.Available[i])
}
}

// if any of the samples failed, return an error
if len(failedSamples) > 0 {
Expand Down Expand Up @@ -210,7 +184,11 @@ func (la *ShareAvailability) Prune(ctx context.Context, h *header.ExtendedHeader

// delete stored samples
for _, sample := range result.Available {
blk, err := bitswap.NewEmptySampleBlock(h.Height(), sample.Row, sample.Col, len(h.DAH.RowRoots))
idx, err := shwap.SampleIndexFromCoordinates(sample.Row, sample.Col, len(h.DAH.RowRoots))
if err != nil {
return err
}
blk, err := bitswap.NewEmptySampleBlock(h.Height(), idx, len(h.DAH.RowRoots))
if err != nil {
return fmt.Errorf("marshal sample ID: %w", err)
}
Expand Down
66 changes: 44 additions & 22 deletions share/availability/light/availability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ import (
"github.com/stretchr/testify/require"

libshare "github.com/celestiaorg/go-square/v2/share"
"github.com/celestiaorg/nmt"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/header"
"github.com/celestiaorg/celestia-node/header/headertest"
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/shwap"
"github.com/celestiaorg/celestia-node/share/shwap/getters/mock"
Expand All @@ -38,22 +40,32 @@ func TestSharesAvailableSuccess(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

eds := edstest.RandEDS(t, 16)
roots, err := share.NewAxisRoots(eds)
square := edstest.RandEDS(t, 16)
roots, err := share.NewAxisRoots(square)
require.NoError(t, err)
eh := headertest.RandExtendedHeaderWithRoot(t, roots)

getter := mock.NewMockGetter(gomock.NewController(t))
getter.EXPECT().
GetShare(gomock.Any(), eh, gomock.Any(), gomock.Any()).
GetSamples(gomock.Any(), eh, gomock.Any()).
DoAndReturn(
func(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) {
rawSh := eds.GetCell(uint(row), uint(col))
sh, err := libshare.NewShare(rawSh)
if err != nil {
return libshare.Share{}, err
func(_ context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
acc := eds.Rsmt2D{ExtendedDataSquare: square}
smpls := make([]shwap.Sample, len(indices))
for i, idx := range indices {
rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots))
if err != nil {
return nil, err
}

smpl, err := acc.Sample(ctx, rowIdx, colIdx)
if err != nil {
return nil, err
}

smpls[i] = smpl
}
return *sh, nil
return smpls, nil
}).
AnyTimes()

Expand Down Expand Up @@ -87,8 +99,8 @@ func TestSharesAvailableSkipSampled(t *testing.T) {
// Create a getter that always returns ErrNotFound
getter := mock.NewMockGetter(gomock.NewController(t))
getter.EXPECT().
GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(libshare.Share{}, shrex.ErrNotFound).
GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil, shrex.ErrNotFound).
AnyTimes()

ds := datastore.NewMapDatastore()
Expand Down Expand Up @@ -147,8 +159,8 @@ func TestSharesAvailableFailed(t *testing.T) {

// Getter doesn't have the eds, so it should fail for all samples
getter.EXPECT().
GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(libshare.Share{}, shrex.ErrNotFound).
GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).
Return(make([]shwap.Sample, avail.params.SampleAmount), shrex.ErrNotFound).
AnyTimes()
err = avail.SharesAvailable(ctx, eh)
require.ErrorIs(t, err, share.ErrNotAvailable)
Expand Down Expand Up @@ -185,7 +197,7 @@ func TestSharesAvailableFailed(t *testing.T) {

// onceGetter should have no more samples stored after the call
successfulGetter.checkOnce(t)
require.ElementsMatch(t, failed.Remaining, successfulGetter.sampledList())
require.ElementsMatch(t, failed.Remaining, len(successfulGetter.sampled))
}

func TestParallelAvailability(t *testing.T) {
Expand Down Expand Up @@ -213,7 +225,7 @@ func TestParallelAvailability(t *testing.T) {
}()
}
wg.Wait()
require.Len(t, successfulGetter.sampledList(), int(avail.params.SampleAmount))
require.Len(t, len(successfulGetter.sampled), int(avail.params.SampleAmount))

// Verify that the sampling result is stored with all samples marked as available
resultData, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots))
Expand Down Expand Up @@ -249,14 +261,24 @@ func (g onceGetter) checkOnce(t *testing.T) {
}
}

func (g onceGetter) sampledList() []Sample {
g.Lock()
defer g.Unlock()
samples := make([]Sample, 0, len(g.sampled))
for s := range g.sampled {
samples = append(samples, s)
func (m onceGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
m.Lock()
defer m.Unlock()

smpls := make([]shwap.Sample, 0, len(indices))
for _, idx := range indices {
rowIdx, colIdx, err := idx.Coordinates(len(hdr.DAH.RowRoots))
if err != nil {
return nil, err
}

s := Sample{Row: rowIdx, Col: colIdx}
if _, ok := m.sampled[s]; ok {
delete(m.sampled, s)
smpls = append(smpls, shwap.Sample{Proof: &nmt.Proof{}})
}
}
return samples
return smpls, nil
}

func (g onceGetter) GetShare(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) {
Expand Down
1 change: 1 addition & 0 deletions share/eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Accessor interface {
// Sample returns share and corresponding proof for row and column indices. Implementation can
// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned
// Sample.
// TODO(@Wondertan): change to SampleIndex
Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error)
// AxisHalf returns half of shares axis of the given type and index. Side is determined by
// implementation. Implementations should indicate the side in the returned AxisHalf.
Expand Down
7 changes: 6 additions & 1 deletion share/eds/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ func (f validation) Size(ctx context.Context) int {
}

func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
_, err := shwap.NewSampleID(1, rowIdx, colIdx, f.Size(ctx))
idx, err := shwap.SampleIndexFromCoordinates(rowIdx, colIdx, f.Size(ctx))
if err != nil {
return shwap.Sample{}, err
}

_, err = shwap.NewSampleID(1, idx, f.Size(ctx))
if err != nil {
return shwap.Sample{}, fmt.Errorf("sample validation: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions share/shwap/getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ var (
//
//go:generate mockgen -destination=getters/mock/getter.go -package=mock . Getter
type Getter interface {
// GetShare gets a Share by coordinates in EDS.
GetShare(ctx context.Context, header *header.ExtendedHeader, row, col int) (libshare.Share, error)
// GetSamples gets samples by their indices.
// Returns Sample slice with requested number of samples in the requested order.
// May return partial response with some samples being empty if they weren't found.
GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []SampleIndex) ([]Sample, error)

// GetEDS gets the full EDS identified by the given extended header.
GetEDS(context.Context, *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error)
Expand Down
21 changes: 6 additions & 15 deletions share/shwap/getters/cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,15 @@ func NewCascadeGetter(getters []shwap.Getter) *CascadeGetter {
}
}

// GetShare gets a share from any of registered shwap.Getters in cascading order.
func (cg *CascadeGetter) GetShare(
ctx context.Context, header *header.ExtendedHeader, row, col int,
) (libshare.Share, error) {
ctx, span := tracer.Start(ctx, "cascade/get-share", trace.WithAttributes(
attribute.Int("row", row),
attribute.Int("col", col),
// GetSamples gets samples from any of registered shwap.Getters in cascading order.
func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleIndex) ([]shwap.Sample, error) {
ctx, span := tracer.Start(ctx, "cascade/get-samples", trace.WithAttributes(
attribute.Int("amount", len(indices)),
))
defer span.End()

upperBound := len(header.DAH.RowRoots)
if row >= upperBound || col >= upperBound {
err := shwap.ErrOutOfBounds
span.RecordError(err)
return libshare.Share{}, err
}
get := func(ctx context.Context, get shwap.Getter) (libshare.Share, error) {
return get.GetShare(ctx, header, row, col)
get := func(ctx context.Context, get shwap.Getter) ([]shwap.Sample, error) {
return get.GetSamples(ctx, hdr, indices)
}

return cascadeGetters(ctx, cg.getters, get)
Expand Down
2 changes: 1 addition & 1 deletion share/shwap/getters/cascade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestCascadeGetter(t *testing.T) {
getter := NewCascadeGetter(getters)
t.Run("GetShare", func(t *testing.T) {
for _, eh := range headers {
sh, err := getter.GetShare(ctx, eh, 0, 0)
sh, err := getter.GetSamples(ctx, eh, []shwap.SampleIndex{0})
assert.NoError(t, err)
assert.NotEmpty(t, sh)
}
Expand Down
Loading

0 comments on commit 9024839

Please sign in to comment.