diff --git a/share/eds/cache/noop.go b/share/eds/cache/noop.go index 8e1c17924a..95a199c345 100644 --- a/share/eds/cache/noop.go +++ b/share/eds/cache/noop.go @@ -38,6 +38,7 @@ var _ Accessor = (*NoopAccessor)(nil) type NoopAccessor struct{} func (n NoopAccessor) Blockstore() (dagstore.ReadBlockstore, error) { + //nolint:nilnil return nil, nil } diff --git a/share/new_eds/rsmt2d_test.go b/share/new_eds/rsmt2d_test.go index ae10d3b4be..5035242e1b 100644 --- a/share/new_eds/rsmt2d_test.go +++ b/share/new_eds/rsmt2d_test.go @@ -3,6 +3,7 @@ package eds import ( "context" "testing" + "time" "github.com/stretchr/testify/require" @@ -13,22 +14,32 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap" ) -func TestRsmt2dSample(t *testing.T) { - eds, root := randRsmt2dAccsessor(t, 8) +func TestMemFile(t *testing.T) { + odsSize := 8 + newAccessor := func(eds *rsmt2d.ExtendedDataSquare) Accessor { + return &Rsmt2D{ExtendedDataSquare: eds} + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) - width := int(eds.Width()) - for rowIdx := 0; rowIdx < width; rowIdx++ { - for colIdx := 0; colIdx < width; colIdx++ { - shr, err := eds.Sample(context.TODO(), rowIdx, colIdx) - require.NoError(t, err) + t.Run("Sample", func(t *testing.T) { + TestAccessorSample(ctx, t, newAccessor, odsSize) + }) - err = shr.Validate(root, rowIdx, colIdx) - require.NoError(t, err) - } - } + t.Run("AxisHalf", func(t *testing.T) { + TestAccessorAxisHalf(ctx, t, newAccessor, odsSize) + }) + + t.Run("Data", func(t *testing.T) { + TestAccessorRowNamespaceData(ctx, t, newAccessor, odsSize) + }) + + t.Run("EDS", func(t *testing.T) { + TestAccessorEds(ctx, t, newAccessor, odsSize) + }) } -func TestRsmt2dHalfRowFrom(t *testing.T) { +func TestRsmt2dHalfRow(t *testing.T) { const odsSize = 8 eds, _ := randRsmt2dAccsessor(t, odsSize) diff --git a/share/new_eds/testing.go b/share/new_eds/testing.go new file mode 100644 index 0000000000..8248e0dcbd --- /dev/null +++ b/share/new_eds/testing.go @@ -0,0 +1,276 @@ +package eds + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/nmt" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/sharetest" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +type createAccessor func(eds *rsmt2d.ExtendedDataSquare) Accessor + +func TestAccessorSample( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + fl := createAccessor(eds) + + dah, err := share.NewRoot(eds) + require.NoError(t, err) + + width := int(eds.Width()) + t.Run("single thread", func(t *testing.T) { + for rowIdx := 0; rowIdx < width; rowIdx++ { + for colIdx := 0; colIdx < width; colIdx++ { + testSample(ctx, t, fl, dah, colIdx, rowIdx) + } + } + }) + + t.Run("parallel", func(t *testing.T) { + wg := sync.WaitGroup{} + for rowIdx := 0; rowIdx < width; rowIdx++ { + for colIdx := 0; colIdx < width; colIdx++ { + wg.Add(1) + go func(rowIdx, colIdx int) { + defer wg.Done() + testSample(ctx, t, fl, dah, rowIdx, colIdx) + }(rowIdx, colIdx) + } + } + wg.Wait() + }) +} + +func testSample( + ctx context.Context, + t *testing.T, + fl Accessor, + dah *share.Root, + rowIdx, colIdx int, +) { + shr, err := fl.Sample(ctx, rowIdx, colIdx) + require.NoError(t, err) + + err = shr.Validate(dah, rowIdx, colIdx) + require.NoError(t, err) +} + +func TestAccessorRowNamespaceData( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + t.Run("included", func(t *testing.T) { + // generate EDS with random data and some Shares with the same namespace + sharesAmount := odsSize * odsSize + namespace := sharetest.RandV0Namespace() + // test with different amount of shares + for amount := 1; amount < sharesAmount; amount++ { + // select random amount of shares, but not less than 1 + eds, dah := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) + f := createAccessor(eds) + + var actualSharesAmount int + // loop over all rows and check that the amount of shares in the namespace is equal to the expected + // amount + for i, root := range dah.RowRoots { + rowData, err := f.RowNamespaceData(ctx, namespace, i) + + // namespace is not included in the row, so there should be no shares + if namespace.IsOutsideRange(root, root) { + require.ErrorIs(t, err, shwap.ErrNamespaceOutsideRange) + require.Len(t, rowData.Shares, 0) + continue + } + + actualSharesAmount += len(rowData.Shares) + require.NoError(t, err) + require.True(t, len(rowData.Shares) > 0) + err = rowData.Validate(dah, namespace, i) + require.NoError(t, err) + } + + // check that the amount of shares in the namespace is equal to the expected amount + require.Equal(t, amount, actualSharesAmount) + } + }) + + t.Run("not included", func(t *testing.T) { + // generate EDS with random data and some Shares with the same namespace + eds := edstest.RandEDS(t, odsSize) + dah, err := share.NewRoot(eds) + require.NoError(t, err) + + // loop over first half of the rows, because the second half is parity and does not contain + // namespaced shares + for i, root := range dah.RowRoots[:odsSize] { + // select namespace that within the range of root namespaces, but is not included + maxNs := nmt.MaxNamespace(root, share.NamespaceSize) + absentNs, err := share.Namespace(maxNs).AddInt(-1) + require.NoError(t, err) + + f := createAccessor(eds) + rowData, err := f.RowNamespaceData(ctx, absentNs, i) + require.NoError(t, err) + + // namespace is not included in the row, so there should be no shares + require.Len(t, rowData.Shares, 0) + require.True(t, rowData.Proof.IsOfAbsence()) + + err = rowData.Validate(dah, absentNs, i) + require.NoError(t, err) + } + }) +} + +func TestAccessorAxisHalf( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + fl := createAccessor(eds) + + t.Run("single thread", func(t *testing.T) { + for _, axisType := range []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} { + for axisIdx := 0; axisIdx < int(eds.Width()); axisIdx++ { + half, err := fl.AxisHalf(ctx, axisType, axisIdx) + require.NoError(t, err) + require.Len(t, half.Shares, odsSize) + + var expected []share.Share + if half.IsParity { + expected = getAxis(eds, axisType, axisIdx)[odsSize:] + } else { + expected = getAxis(eds, axisType, axisIdx)[:odsSize] + } + + require.Equal(t, expected, half.Shares) + } + } + }) + + t.Run("parallel", func(t *testing.T) { + wg := sync.WaitGroup{} + for _, axisType := range []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} { + for i := 0; i < int(eds.Width()); i++ { + wg.Add(1) + go func(axisType rsmt2d.Axis, idx int) { + defer wg.Done() + half, err := fl.AxisHalf(ctx, axisType, idx) + require.NoError(t, err) + require.Len(t, half.Shares, odsSize) + + var expected []share.Share + if half.IsParity { + expected = getAxis(eds, axisType, idx)[odsSize:] + } else { + expected = getAxis(eds, axisType, idx)[:odsSize] + } + + require.Equal(t, expected, half.Shares) + }(axisType, i) + } + } + wg.Wait() + }) +} + +func TestAccessorEds( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + fl := createAccessor(eds) + + shares, err := fl.Shares(ctx) + require.NoError(t, err) + expected := eds.Flattened() + require.Equal(t, expected, shares) +} + +func BenchGetHalfAxisFromAccessor( + ctx context.Context, + b *testing.B, + newAccessor func(size int) Accessor, + minOdsSize, maxOdsSize int, +) { + for size := minOdsSize; size <= maxOdsSize; size *= 2 { + f := newAccessor(size) + + // loop over all possible axis types and quadrants + for _, axisType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { + for _, squareHalf := range []int{0, 1} { + name := fmt.Sprintf("Size:%v/ProofType:%s/squareHalf:%s", size, axisType, strconv.Itoa(squareHalf)) + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := f.AxisHalf(ctx, axisType, f.Size(ctx)/2*(squareHalf)) + require.NoError(b, err) + } + }) + } + } + } +} + +func BenchGetSampleFromAccessor( + ctx context.Context, + b *testing.B, + newAccessor func(size int) Accessor, + minOdsSize, maxOdsSize int, +) { + for size := minOdsSize; size <= maxOdsSize; size *= 2 { + f := newAccessor(size) + + // loop over all possible axis types and quadrants + for _, q := range quadrants { + name := fmt.Sprintf("Size:%v/quadrant:%s", size, q) + b.Run(name, func(b *testing.B) { + rowIdx, colIdx := q.coordinates(f.Size(ctx)) + // warm up cache + _, err := f.Sample(ctx, rowIdx, colIdx) + require.NoError(b, err, q.String()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := f.Sample(ctx, rowIdx, colIdx) + require.NoError(b, err) + } + }) + } + } +} + +type quadrant int + +var quadrants = []quadrant{1, 2, 3, 4} + +func (q quadrant) String() string { + return strconv.Itoa(int(q)) +} + +func (q quadrant) coordinates(edsSize int) (rowIdx, colIdx int) { + colIdx = edsSize/2*(int(q-1)%2) + 1 + rowIdx = edsSize/2*(int(q-1)/2) + 1 + return rowIdx, colIdx +} diff --git a/share/shwap/row_namespace_data.go b/share/shwap/row_namespace_data.go index 2bbf4c06fa..fe7bd6da36 100644 --- a/share/shwap/row_namespace_data.go +++ b/share/shwap/row_namespace_data.go @@ -1,8 +1,10 @@ package shwap import ( + "errors" "fmt" + "github.com/celestiaorg/celestia-app/pkg/appconsts" "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/nmt" nmt_pb "github.com/celestiaorg/nmt/pb" @@ -11,6 +13,11 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap/pb" ) +// ErrNamespaceOutsideRange is returned by RowNamespaceDataFromShares when the target namespace is +// outside of the namespace range for the given row. In this case, the implementation cannot return +// the non-inclusion proof and will return ErrNamespaceOutsideRange. +var ErrNamespaceOutsideRange = errors.New("target namespace is outside of namespace range for the given root") + // RowNamespaceData holds shares and their corresponding proof for a single row within a namespace. type RowNamespaceData struct { Shares []share.Share `json:"shares"` // Shares within the namespace. @@ -24,6 +31,28 @@ func RowNamespaceDataFromShares( namespace share.Namespace, rowIndex int, ) (RowNamespaceData, error) { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(rowIndex)) + nmtTree := nmt.New( + appconsts.NewBaseHashFunc(), + nmt.NamespaceIDSize(appconsts.NamespaceSize), + nmt.IgnoreMaxNamespace(true), + ) + tree.SetTree(nmtTree) + + for _, shr := range shares { + if err := tree.Push(shr); err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to build tree for row %d: %w", rowIndex, err) + } + } + + root, err := tree.Root() + if err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to get root for row %d: %w", rowIndex, err) + } + if namespace.IsOutsideRange(root, root) { + return RowNamespaceData{}, ErrNamespaceOutsideRange + } + var from, count int for i := range len(shares) / 2 { if namespace.Equals(share.GetNamespace(shares[i])) { @@ -37,22 +66,22 @@ func RowNamespaceDataFromShares( break } } + + // if count is 0, then the namespace is not present in the shares. Return non-inclusion proof. if count == 0 { - // FIXME: This should return Non-inclusion proofs instead. Need support in app wrapper to generate - // absence proofs. - return RowNamespaceData{}, fmt.Errorf("no shares found in the namespace for row %d", rowIndex) + proof, err := nmtTree.ProveNamespace(namespace.ToNMT()) + if err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to generate non-inclusion proof for row %d: %w", rowIndex, err) + } + + return RowNamespaceData{ + Proof: &proof, + }, nil } namespacedShares := make([]share.Share, count) copy(namespacedShares, shares[from:from+count]) - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(rowIndex)) - for _, shr := range shares { - if err := tree.Push(shr); err != nil { - return RowNamespaceData{}, fmt.Errorf("failed to build tree for row %d: %w", rowIndex, err) - } - } - proof, err := tree.ProveRange(from, from+count) if err != nil { return RowNamespaceData{}, fmt.Errorf("failed to generate proof for row %d: %w", rowIndex, err)