Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(shwap): Extract eds interface #3452

Merged
merged 11 commits into from
Jun 5, 2024
34 changes: 34 additions & 0 deletions share/new_eds/accessor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package eds

import (
"context"
"io"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// Accessor is an interface for accessing extended data square data.
type Accessor interface {
// Size returns square size of the Accessor.
Size(ctx context.Context) int
// Sample returns share and corresponding proof for row and column indices. Implementation can
// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned
// Sample.
Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error)
// AxisHalf returns half of shares axis of the given type and index. Side is determined by
// implementation. Implementations should indicate the side in the returned AxisHalf.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
// RowNamespaceData returns data for the given namespace and row index.
RowNamespaceData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error)
// Shares returns data shares extracted from the Accessor.
Shares(ctx context.Context) ([]share.Share, error)
}

// AccessorCloser is an interface that groups Accessor and io.Closer interfaces.
type AccessorCloser interface {
Accessor
io.Closer
}
7 changes: 5 additions & 2 deletions share/store/file/axis_half.go → share/new_eds/axis_half.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package file
package eds

import (
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// AxisHalf represents a half of data for a row or column in the EDS.
type AxisHalf struct {
Shares []share.Share
Shares []share.Share
// IsParity indicates whether the half is parity or data.
IsParity bool
}

// ToRow converts the AxisHalf to a shwap.Row.
func (a AxisHalf) ToRow() shwap.Row {
side := shwap.Left
if a.IsParity {
Expand Down
31 changes: 31 additions & 0 deletions share/new_eds/nd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package eds

import (
"context"
"fmt"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// NamespacedData extracts shares for a specific namespace from an EDS, considering
// each row independently. It uses root to determine which rows to extract data from,
// avoiding the need to recalculate the row roots for each row.
func NamespacedData(
ctx context.Context,
root *share.Root,
eds Accessor,
namespace share.Namespace,
) (shwap.NamespacedData, error) {
rowIdxs := share.RowsWithNamespace(root, namespace)
rows := make(shwap.NamespacedData, len(rowIdxs))
var err error
for i, idx := range rowIdxs {
rows[i], err = eds.RowNamespaceData(ctx, namespace, idx)
if err != nil {
return nil, fmt.Errorf("failed to process row %d: %w", idx, err)
}
}

return rows, nil
}
32 changes: 32 additions & 0 deletions share/new_eds/nd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package eds

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/sharetest"
)

func TestNamespacedData(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)

const odsSize = 8
sharesAmount := odsSize * odsSize
namespace := sharetest.RandV0Namespace()
for amount := 1; amount < sharesAmount; amount++ {
eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize)
rsmt2d := Rsmt2D{ExtendedDataSquare: eds}
nd, err := NamespacedData(ctx, root, rsmt2d, namespace)
require.NoError(t, err)
require.True(t, len(nd) > 0)
require.Len(t, nd.Flatten(), amount)

err = nd.Validate(root, namespace)
require.NoError(t, err)
}
}
116 changes: 116 additions & 0 deletions share/new_eds/rsmt2d.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package eds

import (
"context"
"fmt"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ Accessor = Rsmt2D{}

// Rsmt2D is a rsmt2d based in-memory implementation of Accessor.
type Rsmt2D struct {
*rsmt2d.ExtendedDataSquare
}

// Size returns the size of the Extended Data Square.
func (eds Rsmt2D) Size(context.Context) int {
return int(eds.Width())
}

// Sample returns share and corresponding proof for row and column indices.
func (eds Rsmt2D) Sample(
_ context.Context,
rowIdx, colIdx int,
) (shwap.Sample, error) {
return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row)
}

// SampleForProofAxis samples a share from an Extended Data Square based on the provided
// row and column indices and proof axis. It returns a sample with the share and proof.
func (eds Rsmt2D) SampleForProofAxis(
rowIdx, colIdx int,
proofType rsmt2d.Axis,
) (shwap.Sample, error) {
axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType)
shares := getAxis(eds.ExtendedDataSquare, proofType, axisIdx)

tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(eds.Width()/2), uint(axisIdx))
for _, shr := range shares {
err := tree.Push(shr)
if err != nil {
return shwap.Sample{}, fmt.Errorf("while pushing shares to NMT: %w", err)
}
}

prf, err := tree.ProveRange(shrIdx, shrIdx+1)
if err != nil {
return shwap.Sample{}, fmt.Errorf("while proving range share over NMT: %w", err)
}

return shwap.Sample{
Share: shares[shrIdx],
Proof: &prf,
ProofType: proofType,
}, nil
}

// AxisHalf returns Shares for the first half of the axis of the given type and index.
func (eds Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
shares := getAxis(eds.ExtendedDataSquare, axisType, axisIdx)
halfShares := shares[:eds.Width()/2]
return AxisHalf{
Shares: halfShares,
IsParity: false,
}, nil
}

// HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and
// side.
func (eds Rsmt2D) HalfRow(idx int, side shwap.RowSide) shwap.Row {
shares := eds.ExtendedDataSquare.Row(uint(idx))
return shwap.RowFromShares(shares, side)
}

// RowNamespaceData returns data for the given namespace and row index.
func (eds Rsmt2D) RowNamespaceData(
_ context.Context,
namespace share.Namespace,
rowIdx int,
) (shwap.RowNamespaceData, error) {
shares := eds.Row(uint(rowIdx))
return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx)
}

// Shares returns data shares extracted from the EDS. It returns new copy of the shares each
// time.
func (eds Rsmt2D) Shares(_ context.Context) ([]share.Share, error) {
return eds.ExtendedDataSquare.Flattened(), nil
}

func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share {
switch axisType {
case rsmt2d.Row:
return eds.Row(uint(axisIdx))
case rsmt2d.Col:
return eds.Col(uint(axisIdx))
default:
panic("unknown axis")
}
}

func relativeIndexes(rowIdx, colIdx int, axisType rsmt2d.Axis) (axisIdx, shrIdx int) {
switch axisType {
case rsmt2d.Row:
return rowIdx, colIdx
case rsmt2d.Col:
return colIdx, rowIdx
default:
panic(fmt.Sprintf("invalid proof type: %d", axisType))
}
}
74 changes: 74 additions & 0 deletions share/new_eds/rsmt2d_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package eds

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/shwap"
)

func TestRsmt2dSample(t *testing.T) {
eds, root := randRsmt2dAccsessor(t, 8)

width := int(eds.Width())
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
shr, err := eds.Sample(context.TODO(), rowIdx, colIdx)
require.NoError(t, err)

err = shr.Validate(root, rowIdx, colIdx)
require.NoError(t, err)
}
}
}

func TestRsmt2dHalfRowFrom(t *testing.T) {
const odsSize = 8
eds, _ := randRsmt2dAccsessor(t, odsSize)

for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for _, side := range []shwap.RowSide{shwap.Left, shwap.Right} {
row := eds.HalfRow(rowIdx, side)

want := eds.Row(uint(rowIdx))
shares, err := row.Shares()
require.NoError(t, err)
require.Equal(t, want, shares)
}
}
}

func TestRsmt2dSampleForProofAxis(t *testing.T) {
const odsSize = 8
eds := edstest.RandEDS(t, odsSize)
accessor := Rsmt2D{ExtendedDataSquare: eds}

for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for colIdx := 0; colIdx < odsSize*2; colIdx++ {
sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType)
require.NoError(t, err)

want := eds.GetCell(uint(rowIdx), uint(colIdx))
require.Equal(t, want, sample.Share)
require.Equal(t, proofType, sample.ProofType)
require.NotNil(t, sample.Proof)
require.Equal(t, sample.Proof.End()-sample.Proof.Start(), 1)
require.Len(t, sample.Proof.Nodes(), 4)
}
}
}
}

func randRsmt2dAccsessor(t *testing.T, size int) (Rsmt2D, *share.Root) {
eds := edstest.RandEDS(t, size)
root, err := share.NewRoot(eds)
require.NoError(t, err)
return Rsmt2D{ExtendedDataSquare: eds}, root
}
26 changes: 0 additions & 26 deletions share/shwap/namespace_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,13 @@ package shwap
import (
"fmt"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
)

// NamespacedData stores collections of RowNamespaceData, each representing shares and their proofs
// within a namespace.
type NamespacedData []RowNamespaceData

// NamespacedDataFromEDS extracts shares for a specific namespace from an EDS, considering
// each row independently.
func NamespacedDataFromEDS(
square *rsmt2d.ExtendedDataSquare,
namespace share.Namespace,
) (NamespacedData, error) {
root, err := share.NewRoot(square)
if err != nil {
return nil, fmt.Errorf("error computing root: %w", err)
}

rowIdxs := share.RowsWithNamespace(root, namespace)
rows := make(NamespacedData, len(rowIdxs))
for i, idx := range rowIdxs {
shares := square.Row(uint(idx))
rows[i], err = RowNamespaceDataFromShares(shares, namespace, idx)
if err != nil {
return nil, fmt.Errorf("failed to process row %d: %w", idx, err)
}
}

return rows, nil
}

// Flatten combines all shares from all rows within the namespace into a single slice.
func (ns NamespacedData) Flatten() []share.Share {
var shares []share.Share
Expand Down
14 changes: 7 additions & 7 deletions share/shwap/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap/pb"
Expand Down Expand Up @@ -35,14 +34,12 @@ func NewRow(halfShares []share.Share, side RowSide) Row {

// RowFromEDS constructs a new Row from an Extended Data Square based on the specified index and
// side.
func RowFromEDS(square *rsmt2d.ExtendedDataSquare, idx int, side RowSide) Row {
sqrLn := int(square.Width())
shares := square.Row(uint(idx))
func RowFromShares(shares []share.Share, side RowSide) Row {
var halfShares []share.Share
if side == Right {
halfShares = shares[sqrLn/2:] // Take the right half of the shares.
halfShares = shares[len(shares)/2:] // Take the right half of the shares.
} else {
halfShares = shares[:sqrLn/2] // Take the left half of the shares.
halfShares = shares[:len(shares)/2] // Take the left half of the shares.
}

return NewRow(halfShares, side)
Expand Down Expand Up @@ -95,7 +92,10 @@ func (r Row) Validate(dah *share.Root, idx int) error {
return fmt.Errorf("invalid RowSide: %d", r.side)
}

return r.verifyInclusion(dah, idx)
if err := r.verifyInclusion(dah, idx); err != nil {
return fmt.Errorf("%w: %w", ErrorFailedVerification, err)
}
return nil
}

// verifyInclusion verifies the integrity of the row's shares against the provided root hash for the
Expand Down
Loading
Loading