Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(shwap): Extract eds interface #3452

Merged
merged 11 commits into from
Jun 5, 2024
7 changes: 5 additions & 2 deletions share/store/file/axis_half.go → share/new_eds/axis_half.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package file
package eds

import (
"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// AxisHalf represents a half of a row or column in a shwap.
type AxisHalf struct {
Shares []share.Share
Shares []share.Share
// IsParity indicates whether the half is parity or data.
IsParity bool
}

// ToRow converts the AxisHalf to a shwap.Row.
func (a AxisHalf) ToRow() shwap.Row {
side := shwap.Left
if a.IsParity {
Expand Down
29 changes: 29 additions & 0 deletions share/new_eds/eds.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package eds

import (
"context"
"io"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

// EDS is an interface for accessing extended data square data.
type EDS interface {
io.Closer
// Size returns square size of the file.
Size() int
// Sample returns share and corresponding proof for row and column indices. Implementation can choose
// which axis to use for proof. Chosen axis for proof should be indicated in the returned Sample.
Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error)
// AxisHalf returns half of shares axis of the given type and index. Side is determined by implementation
// and is not guaranteed to be the same for all implementations. Implementations should indicate the side
// in the returned AxisHalf.
AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error)
// RowData returns data for the given namespace and row index.
RowData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error)
// EDS returns extended data square stored in the file.
EDS(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error)
}
133 changes: 133 additions & 0 deletions share/new_eds/in_mem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package eds

import (
"context"
"fmt"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ EDS = InMem{}

// InMem is an in-memory implementation of EDS.
type InMem struct {
*rsmt2d.ExtendedDataSquare
}

// Close does nothing.
func (eds InMem) Close() error {
return nil
}

// Size returns the size of the Extended Data Square.
func (eds InMem) Size() int {
return int(eds.Width())
}

// Sample returns share and corresponding proof for row and column indices.
func (eds InMem) Sample(
_ context.Context,
rowIdx, colIdx int,
) (shwap.Sample, error) {
return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row)
}

// SampleForProofAxis samples a share from an Extended Data Square based on the provided
// row and column indices and proof axis. It returns a sample with the share and proof.
func (eds InMem) SampleForProofAxis(
rowIdx, colIdx int,
proofType rsmt2d.Axis,
) (shwap.Sample, error) {
var axisIdx, shrIdx int
switch proofType {
case rsmt2d.Row:
axisIdx, shrIdx = rowIdx, colIdx
case rsmt2d.Col:
axisIdx, shrIdx = colIdx, rowIdx
default:
return shwap.Sample{}, fmt.Errorf("invalid proof type: %d", proofType)
}
shrs := getAxis(eds.ExtendedDataSquare, proofType, axisIdx)

tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(eds.Width()/2), uint(axisIdx))
for _, shr := range shrs {
err := tree.Push(shr)
if err != nil {
return shwap.Sample{}, fmt.Errorf("while pushing shares to NMT: %w", err)
}
}

prf, err := tree.ProveRange(shrIdx, shrIdx+1)
if err != nil {
return shwap.Sample{}, fmt.Errorf("while proving range share over NMT: %w", err)
}

return shwap.Sample{
Share: shrs[shrIdx],
Proof: &prf,
ProofType: proofType,
}, nil
}

// AxisHalf returns Shares for the first half of the axis of the given type and index.
func (eds InMem) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
return AxisHalf{
Shares: getAxis(eds.ExtendedDataSquare, axisType, axisIdx)[:eds.Size()/2],
IsParity: false,
}, nil
}

// HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and
// side.
func (eds InMem) HalfRow(idx int, side shwap.RowSide) shwap.Row {
shares := eds.ExtendedDataSquare.Row(uint(idx))
return shwap.RowFromShares(shares, side)
}

// RowData returns data for the given namespace and row index.
func (eds InMem) RowData(_ context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) {
shares := eds.Row(uint(rowIdx))
return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx)
}

// NamespacedDataFromEDS extracts shares for a specific namespace from an EDS, considering
// each row independently.
func (eds InMem) NamespacedData(
namespace share.Namespace,
) (shwap.NamespacedData, error) {
root, err := share.NewRoot(eds.ExtendedDataSquare)
if err != nil {
return nil, fmt.Errorf("error computing root: %w", err)
}

rowIdxs := share.RowsWithNamespace(root, namespace)
rows := make(shwap.NamespacedData, len(rowIdxs))
for i, idx := range rowIdxs {
rows[i], err = eds.RowData(context.TODO(), namespace, idx)
if err != nil {
return nil, fmt.Errorf("failed to process row %d: %w", idx, err)
}
}

return rows, nil
}

// EDS returns extended data square stored in the file.
func (eds InMem) EDS(_ context.Context) (*rsmt2d.ExtendedDataSquare, error) {
return eds.ExtendedDataSquare, nil
}

func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share {
switch axisType {
case rsmt2d.Row:
return eds.Row(uint(axisIdx))
case rsmt2d.Col:
return eds.Col(uint(axisIdx))
default:
panic("unknown axis")
}
}
93 changes: 93 additions & 0 deletions share/new_eds/in_mem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package eds

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/sharetest"
"github.com/celestiaorg/celestia-node/share/shwap"
)

func TestMemFileSample(t *testing.T) {
eds, root := randInMemEDS(t, 8)

width := int(eds.Width())
for rowIdx := 0; rowIdx < width; rowIdx++ {
for colIdx := 0; colIdx < width; colIdx++ {
shr, err := eds.Sample(context.TODO(), rowIdx, colIdx)
require.NoError(t, err)

err = shr.Validate(root, rowIdx, colIdx)
require.NoError(t, err)
}
}
}

func TestHalfRowFromInMem(t *testing.T) {
const odsSize = 8
eds, _ := randInMemEDS(t, odsSize)

for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for _, side := range []shwap.RowSide{shwap.Left, shwap.Right} {
row := eds.HalfRow(rowIdx, side)

want := eds.Row(uint(rowIdx))
shares, err := row.Shares()
require.NoError(t, err)
require.Equal(t, want, shares)
}
}
}

func TestInMemNamespacedData(t *testing.T) {
const odsSize = 8

sharesAmount := odsSize * odsSize
namespace := sharetest.RandV0Namespace()
for amount := 1; amount < sharesAmount; amount++ {
eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize)
inMem := InMem{ExtendedDataSquare: eds}
nd, err := inMem.NamespacedData(namespace)
require.NoError(t, err)
require.True(t, len(nd) > 0)
require.Len(t, nd.Flatten(), amount)

err = nd.Validate(root, namespace)
require.NoError(t, err)
}
}

func TestInMemSampleForProofAxis(t *testing.T) {
const odsSize = 8
eds := edstest.RandEDS(t, odsSize)
inMem := InMem{ExtendedDataSquare: eds}

for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ {
for colIdx := 0; colIdx < odsSize*2; colIdx++ {
sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType)
require.NoError(t, err)

want := eds.GetCell(uint(rowIdx), uint(colIdx))
require.Equal(t, want, sample.Share)
require.Equal(t, proofType, sample.ProofType)
require.NotNil(t, sample.Proof)
require.Equal(t, sample.Proof.End()-sample.Proof.Start(), 1)
require.Len(t, sample.Proof.Nodes(), 4)
}
}
}
}

func randInMemEDS(t *testing.T, size int) (InMem, *share.Root) {
eds := edstest.RandEDS(t, size)
root, err := share.NewRoot(eds)
require.NoError(t, err)
return InMem{ExtendedDataSquare: eds}, root
}
26 changes: 0 additions & 26 deletions share/shwap/namespace_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,13 @@ package shwap
import (
"fmt"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
)

// NamespacedData stores collections of RowNamespaceData, each representing shares and their proofs
// within a namespace.
type NamespacedData []RowNamespaceData

// NamespacedDataFromEDS extracts shares for a specific namespace from an EDS, considering
// each row independently.
func NamespacedDataFromEDS(
square *rsmt2d.ExtendedDataSquare,
namespace share.Namespace,
) (NamespacedData, error) {
root, err := share.NewRoot(square)
if err != nil {
return nil, fmt.Errorf("error computing root: %w", err)
}

rowIdxs := share.RowsWithNamespace(root, namespace)
rows := make(NamespacedData, len(rowIdxs))
for i, idx := range rowIdxs {
shares := square.Row(uint(idx))
rows[i], err = RowNamespaceDataFromShares(shares, namespace, idx)
if err != nil {
return nil, fmt.Errorf("failed to process row %d: %w", idx, err)
}
}

return rows, nil
}

// Flatten combines all shares from all rows within the namespace into a single slice.
func (ns NamespacedData) Flatten() []share.Share {
var shares []share.Share
Expand Down
9 changes: 3 additions & 6 deletions share/shwap/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap/pb"
Expand Down Expand Up @@ -35,14 +34,12 @@ func NewRow(halfShares []share.Share, side RowSide) Row {

// RowFromEDS constructs a new Row from an Extended Data Square based on the specified index and
// side.
func RowFromEDS(square *rsmt2d.ExtendedDataSquare, idx int, side RowSide) Row {
sqrLn := int(square.Width())
shares := square.Row(uint(idx))
func RowFromShares(shares []share.Share, side RowSide) Row {
var halfShares []share.Share
if side == Right {
halfShares = shares[sqrLn/2:] // Take the right half of the shares.
halfShares = shares[len(shares)/2:] // Take the right half of the shares.
} else {
halfShares = shares[:sqrLn/2] // Take the left half of the shares.
halfShares = shares[:len(shares)/2] // Take the left half of the shares.
}

return NewRow(halfShares, side)
Expand Down
17 changes: 0 additions & 17 deletions share/shwap/row_namespace_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"github.com/celestiaorg/celestia-app/pkg/wrapper"
"github.com/celestiaorg/nmt"
nmt_pb "github.com/celestiaorg/nmt/pb"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap/pb"
Expand All @@ -18,22 +17,6 @@ type RowNamespaceData struct {
Proof *nmt.Proof `json:"proof"` // Proof of the shares' inclusion in the namespace.
}

// RowNamespaceDataFromEDS extracts and constructs a RowNamespaceData from the row of given EDS
// identified by the index and the namespace.
func RowNamespaceDataFromEDS(
eds *rsmt2d.ExtendedDataSquare,
namespace share.Namespace,
rowIdx int,
) (RowNamespaceData, error) {
shares := eds.Row(uint(rowIdx))
rowData, err := RowNamespaceDataFromShares(shares, namespace, rowIdx)
if err != nil {
return RowNamespaceData{}, err
}

return rowData, nil
}

// RowNamespaceDataFromShares extracts and constructs a RowNamespaceData from shares within the
// specified namespace.
func RowNamespaceDataFromShares(
Expand Down
Loading
Loading