From 3dd110030ea92daf30d79547cefb8e862e14b372 Mon Sep 17 00:00:00 2001 From: Vlad <13818348+walldiss@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:55:53 +0500 Subject: [PATCH] refactor(shwap): Extract eds interface (#3452) - Reintroduce file interface as eds interface. Change aims to allow usage of EDS interface outside of storage package and to be high level interface of EDS methods. - Renames of eds interface methods to align with returned shwap types names - Share() -> Sample - Data -> Row data - Extracts NewFromEDS functions to eds file methods - moves associated tests to eds pkg **Additional refactoring:** - **Change Interface Name**: Realized that 'EDS' is a terrible name for an interface. Renamed `eds.EDS` to `eds.Accessor` to more accurately reflect its functionality rather than its internal content. - **Separate Closer**: Extracted `Closer` from `Accessor`. Now it is available in a new composite interface `AccessorCloser`. - **Rename InMem**: Renamed `InMem` to `rsmt2d` to better align with its usage. - **Decouple NamespacedData**: Separated `NamespacedData` from the `rsmt2d` implementation. It is now a standalone function. - **Update EDS Method**: Replaced the `EDS()` method with `Flattened`, similar to `rsmt2d`. Considered introducing two separate methods, `Flattened` and `FlattenedODS`, with the latter to be potentially added later. Proposed to park this suggestion in an issue for future consideration. --- share/new_eds/accessor.go | 34 +++++ share/{store/file => new_eds}/axis_half.go | 7 +- share/new_eds/nd.go | 31 +++++ share/new_eds/nd_test.go | 32 +++++ share/new_eds/rsmt2d.go | 116 ++++++++++++++++++ share/new_eds/rsmt2d_test.go | 74 +++++++++++ share/shwap/namespace_data.go | 26 ---- share/shwap/row.go | 14 +-- share/shwap/row_namespace_data.go | 19 +-- ...ata_test.go => row_namespace_data_test.go} | 44 +++---- share/shwap/row_test.go | 27 ++-- share/shwap/sample.go | 47 +------ share/shwap/sample_test.go | 83 +++++-------- share/store/file/eds_file.go | 25 ---- share/store/file/mem_file.go | 114 ----------------- share/store/file/mem_file_test.go | 51 -------- 16 files changed, 372 insertions(+), 372 deletions(-) create mode 100644 share/new_eds/accessor.go rename share/{store/file => new_eds}/axis_half.go (58%) create mode 100644 share/new_eds/nd.go create mode 100644 share/new_eds/nd_test.go create mode 100644 share/new_eds/rsmt2d.go create mode 100644 share/new_eds/rsmt2d_test.go rename share/shwap/{namespaced_data_test.go => row_namespace_data_test.go} (69%) delete mode 100644 share/store/file/eds_file.go delete mode 100644 share/store/file/mem_file.go delete mode 100644 share/store/file/mem_file_test.go diff --git a/share/new_eds/accessor.go b/share/new_eds/accessor.go new file mode 100644 index 0000000000..d31472546d --- /dev/null +++ b/share/new_eds/accessor.go @@ -0,0 +1,34 @@ +package eds + +import ( + "context" + "io" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +// Accessor is an interface for accessing extended data square data. +type Accessor interface { + // Size returns square size of the Accessor. + Size(ctx context.Context) int + // Sample returns share and corresponding proof for row and column indices. Implementation can + // choose which axis to use for proof. Chosen axis for proof should be indicated in the returned + // Sample. + Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) + // AxisHalf returns half of shares axis of the given type and index. Side is determined by + // implementation. Implementations should indicate the side in the returned AxisHalf. + AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) + // RowNamespaceData returns data for the given namespace and row index. + RowNamespaceData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) + // Shares returns data shares extracted from the Accessor. + Shares(ctx context.Context) ([]share.Share, error) +} + +// AccessorCloser is an interface that groups Accessor and io.Closer interfaces. +type AccessorCloser interface { + Accessor + io.Closer +} diff --git a/share/store/file/axis_half.go b/share/new_eds/axis_half.go similarity index 58% rename from share/store/file/axis_half.go rename to share/new_eds/axis_half.go index 58b9306b85..17b29de591 100644 --- a/share/store/file/axis_half.go +++ b/share/new_eds/axis_half.go @@ -1,15 +1,18 @@ -package file +package eds import ( "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/shwap" ) +// AxisHalf represents a half of data for a row or column in the EDS. type AxisHalf struct { - Shares []share.Share + Shares []share.Share + // IsParity indicates whether the half is parity or data. IsParity bool } +// ToRow converts the AxisHalf to a shwap.Row. func (a AxisHalf) ToRow() shwap.Row { side := shwap.Left if a.IsParity { diff --git a/share/new_eds/nd.go b/share/new_eds/nd.go new file mode 100644 index 0000000000..6031acc41e --- /dev/null +++ b/share/new_eds/nd.go @@ -0,0 +1,31 @@ +package eds + +import ( + "context" + "fmt" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +// NamespacedData extracts shares for a specific namespace from an EDS, considering +// each row independently. It uses root to determine which rows to extract data from, +// avoiding the need to recalculate the row roots for each row. +func NamespacedData( + ctx context.Context, + root *share.Root, + eds Accessor, + namespace share.Namespace, +) (shwap.NamespacedData, error) { + rowIdxs := share.RowsWithNamespace(root, namespace) + rows := make(shwap.NamespacedData, len(rowIdxs)) + var err error + for i, idx := range rowIdxs { + rows[i], err = eds.RowNamespaceData(ctx, namespace, idx) + if err != nil { + return nil, fmt.Errorf("failed to process row %d: %w", idx, err) + } + } + + return rows, nil +} diff --git a/share/new_eds/nd_test.go b/share/new_eds/nd_test.go new file mode 100644 index 0000000000..a0780292ef --- /dev/null +++ b/share/new_eds/nd_test.go @@ -0,0 +1,32 @@ +package eds + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestNamespacedData(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + const odsSize = 8 + sharesAmount := odsSize * odsSize + namespace := sharetest.RandV0Namespace() + for amount := 1; amount < sharesAmount; amount++ { + eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) + rsmt2d := Rsmt2D{ExtendedDataSquare: eds} + nd, err := NamespacedData(ctx, root, rsmt2d, namespace) + require.NoError(t, err) + require.True(t, len(nd) > 0) + require.Len(t, nd.Flatten(), amount) + + err = nd.Validate(root, namespace) + require.NoError(t, err) + } +} diff --git a/share/new_eds/rsmt2d.go b/share/new_eds/rsmt2d.go new file mode 100644 index 0000000000..e37be86834 --- /dev/null +++ b/share/new_eds/rsmt2d.go @@ -0,0 +1,116 @@ +package eds + +import ( + "context" + "fmt" + + "github.com/celestiaorg/celestia-app/pkg/wrapper" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ Accessor = Rsmt2D{} + +// Rsmt2D is a rsmt2d based in-memory implementation of Accessor. +type Rsmt2D struct { + *rsmt2d.ExtendedDataSquare +} + +// Size returns the size of the Extended Data Square. +func (eds Rsmt2D) Size(context.Context) int { + return int(eds.Width()) +} + +// Sample returns share and corresponding proof for row and column indices. +func (eds Rsmt2D) Sample( + _ context.Context, + rowIdx, colIdx int, +) (shwap.Sample, error) { + return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row) +} + +// SampleForProofAxis samples a share from an Extended Data Square based on the provided +// row and column indices and proof axis. It returns a sample with the share and proof. +func (eds Rsmt2D) SampleForProofAxis( + rowIdx, colIdx int, + proofType rsmt2d.Axis, +) (shwap.Sample, error) { + axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType) + shares := getAxis(eds.ExtendedDataSquare, proofType, axisIdx) + + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(eds.Width()/2), uint(axisIdx)) + for _, shr := range shares { + err := tree.Push(shr) + if err != nil { + return shwap.Sample{}, fmt.Errorf("while pushing shares to NMT: %w", err) + } + } + + prf, err := tree.ProveRange(shrIdx, shrIdx+1) + if err != nil { + return shwap.Sample{}, fmt.Errorf("while proving range share over NMT: %w", err) + } + + return shwap.Sample{ + Share: shares[shrIdx], + Proof: &prf, + ProofType: proofType, + }, nil +} + +// AxisHalf returns Shares for the first half of the axis of the given type and index. +func (eds Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { + shares := getAxis(eds.ExtendedDataSquare, axisType, axisIdx) + halfShares := shares[:eds.Width()/2] + return AxisHalf{ + Shares: halfShares, + IsParity: false, + }, nil +} + +// HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and +// side. +func (eds Rsmt2D) HalfRow(idx int, side shwap.RowSide) shwap.Row { + shares := eds.ExtendedDataSquare.Row(uint(idx)) + return shwap.RowFromShares(shares, side) +} + +// RowNamespaceData returns data for the given namespace and row index. +func (eds Rsmt2D) RowNamespaceData( + _ context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + shares := eds.Row(uint(rowIdx)) + return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx) +} + +// Shares returns data shares extracted from the EDS. It returns new copy of the shares each +// time. +func (eds Rsmt2D) Shares(_ context.Context) ([]share.Share, error) { + return eds.ExtendedDataSquare.Flattened(), nil +} + +func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share { + switch axisType { + case rsmt2d.Row: + return eds.Row(uint(axisIdx)) + case rsmt2d.Col: + return eds.Col(uint(axisIdx)) + default: + panic("unknown axis") + } +} + +func relativeIndexes(rowIdx, colIdx int, axisType rsmt2d.Axis) (axisIdx, shrIdx int) { + switch axisType { + case rsmt2d.Row: + return rowIdx, colIdx + case rsmt2d.Col: + return colIdx, rowIdx + default: + panic(fmt.Sprintf("invalid proof type: %d", axisType)) + } +} diff --git a/share/new_eds/rsmt2d_test.go b/share/new_eds/rsmt2d_test.go new file mode 100644 index 0000000000..ae10d3b4be --- /dev/null +++ b/share/new_eds/rsmt2d_test.go @@ -0,0 +1,74 @@ +package eds + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +func TestRsmt2dSample(t *testing.T) { + eds, root := randRsmt2dAccsessor(t, 8) + + width := int(eds.Width()) + for rowIdx := 0; rowIdx < width; rowIdx++ { + for colIdx := 0; colIdx < width; colIdx++ { + shr, err := eds.Sample(context.TODO(), rowIdx, colIdx) + require.NoError(t, err) + + err = shr.Validate(root, rowIdx, colIdx) + require.NoError(t, err) + } + } +} + +func TestRsmt2dHalfRowFrom(t *testing.T) { + const odsSize = 8 + eds, _ := randRsmt2dAccsessor(t, odsSize) + + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for _, side := range []shwap.RowSide{shwap.Left, shwap.Right} { + row := eds.HalfRow(rowIdx, side) + + want := eds.Row(uint(rowIdx)) + shares, err := row.Shares() + require.NoError(t, err) + require.Equal(t, want, shares) + } + } +} + +func TestRsmt2dSampleForProofAxis(t *testing.T) { + const odsSize = 8 + eds := edstest.RandEDS(t, odsSize) + accessor := Rsmt2D{ExtendedDataSquare: eds} + + for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for colIdx := 0; colIdx < odsSize*2; colIdx++ { + sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType) + require.NoError(t, err) + + want := eds.GetCell(uint(rowIdx), uint(colIdx)) + require.Equal(t, want, sample.Share) + require.Equal(t, proofType, sample.ProofType) + require.NotNil(t, sample.Proof) + require.Equal(t, sample.Proof.End()-sample.Proof.Start(), 1) + require.Len(t, sample.Proof.Nodes(), 4) + } + } + } +} + +func randRsmt2dAccsessor(t *testing.T, size int) (Rsmt2D, *share.Root) { + eds := edstest.RandEDS(t, size) + root, err := share.NewRoot(eds) + require.NoError(t, err) + return Rsmt2D{ExtendedDataSquare: eds}, root +} diff --git a/share/shwap/namespace_data.go b/share/shwap/namespace_data.go index 36cdd00444..687958be15 100644 --- a/share/shwap/namespace_data.go +++ b/share/shwap/namespace_data.go @@ -3,8 +3,6 @@ package shwap import ( "fmt" - "github.com/celestiaorg/rsmt2d" - "github.com/celestiaorg/celestia-node/share" ) @@ -12,30 +10,6 @@ import ( // within a namespace. type NamespacedData []RowNamespaceData -// NamespacedDataFromEDS extracts shares for a specific namespace from an EDS, considering -// each row independently. -func NamespacedDataFromEDS( - square *rsmt2d.ExtendedDataSquare, - namespace share.Namespace, -) (NamespacedData, error) { - root, err := share.NewRoot(square) - if err != nil { - return nil, fmt.Errorf("error computing root: %w", err) - } - - rowIdxs := share.RowsWithNamespace(root, namespace) - rows := make(NamespacedData, len(rowIdxs)) - for i, idx := range rowIdxs { - shares := square.Row(uint(idx)) - rows[i], err = RowNamespaceDataFromShares(shares, namespace, idx) - if err != nil { - return nil, fmt.Errorf("failed to process row %d: %w", idx, err) - } - } - - return rows, nil -} - // Flatten combines all shares from all rows within the namespace into a single slice. func (ns NamespacedData) Flatten() []share.Share { var shares []share.Share diff --git a/share/shwap/row.go b/share/shwap/row.go index 046ff5eceb..f03c7366f5 100644 --- a/share/shwap/row.go +++ b/share/shwap/row.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/celestiaorg/celestia-app/pkg/wrapper" - "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/shwap/pb" @@ -35,14 +34,12 @@ func NewRow(halfShares []share.Share, side RowSide) Row { // RowFromEDS constructs a new Row from an Extended Data Square based on the specified index and // side. -func RowFromEDS(square *rsmt2d.ExtendedDataSquare, idx int, side RowSide) Row { - sqrLn := int(square.Width()) - shares := square.Row(uint(idx)) +func RowFromShares(shares []share.Share, side RowSide) Row { var halfShares []share.Share if side == Right { - halfShares = shares[sqrLn/2:] // Take the right half of the shares. + halfShares = shares[len(shares)/2:] // Take the right half of the shares. } else { - halfShares = shares[:sqrLn/2] // Take the left half of the shares. + halfShares = shares[:len(shares)/2] // Take the left half of the shares. } return NewRow(halfShares, side) @@ -95,7 +92,10 @@ func (r Row) Validate(dah *share.Root, idx int) error { return fmt.Errorf("invalid RowSide: %d", r.side) } - return r.verifyInclusion(dah, idx) + if err := r.verifyInclusion(dah, idx); err != nil { + return fmt.Errorf("%w: %w", ErrFailedVerification, err) + } + return nil } // verifyInclusion verifies the integrity of the row's shares against the provided root hash for the diff --git a/share/shwap/row_namespace_data.go b/share/shwap/row_namespace_data.go index 01af228a8d..2bbf4c06fa 100644 --- a/share/shwap/row_namespace_data.go +++ b/share/shwap/row_namespace_data.go @@ -6,7 +6,6 @@ import ( "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/nmt" nmt_pb "github.com/celestiaorg/nmt/pb" - "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/shwap/pb" @@ -18,22 +17,6 @@ type RowNamespaceData struct { Proof *nmt.Proof `json:"proof"` // Proof of the shares' inclusion in the namespace. } -// RowNamespaceDataFromEDS extracts and constructs a RowNamespaceData from the row of given EDS -// identified by the index and the namespace. -func RowNamespaceDataFromEDS( - eds *rsmt2d.ExtendedDataSquare, - namespace share.Namespace, - rowIdx int, -) (RowNamespaceData, error) { - shares := eds.Row(uint(rowIdx)) - rowData, err := RowNamespaceDataFromShares(shares, namespace, rowIdx) - if err != nil { - return RowNamespaceData{}, err - } - - return rowData, nil -} - // RowNamespaceDataFromShares extracts and constructs a RowNamespaceData from shares within the // specified namespace. func RowNamespaceDataFromShares( @@ -144,7 +127,7 @@ func (rnd RowNamespaceData) Validate(dah *share.Root, namespace share.Namespace, } if !rnd.verifyInclusion(rowRoot, namespace) { - return fmt.Errorf("inclusion proof failed for row %d", rowIdx) + return fmt.Errorf("%w for row: %d", ErrFailedVerification, rowIdx) } return nil } diff --git a/share/shwap/namespaced_data_test.go b/share/shwap/row_namespace_data_test.go similarity index 69% rename from share/shwap/namespaced_data_test.go rename to share/shwap/row_namespace_data_test.go index 54663828a4..07f434b87e 100644 --- a/share/shwap/namespaced_data_test.go +++ b/share/shwap/row_namespace_data_test.go @@ -1,15 +1,19 @@ -package shwap +package shwap_test import ( "bytes" + "context" "slices" "testing" + "time" "github.com/stretchr/testify/require" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds/edstest" + eds "github.com/celestiaorg/celestia-node/share/new_eds" "github.com/celestiaorg/celestia-node/share/sharetest" + "github.com/celestiaorg/celestia-node/share/shwap" ) func TestNamespacedRowFromShares(t *testing.T) { @@ -26,7 +30,7 @@ func TestNamespacedRowFromShares(t *testing.T) { require.NoError(t, err) extended := slices.Concat(shares, parity) - nr, err := RowNamespaceDataFromShares(extended, minNamespace, 0) + nr, err := shwap.RowNamespaceDataFromShares(extended, minNamespace, 0) require.NoError(t, err) require.Equal(t, namespacedAmount, len(nr.Shares)) } @@ -46,35 +50,23 @@ func TestNamespacedRowFromSharesNonIncluded(t *testing.T) { require.NoError(t, err) extended := slices.Concat(shares, parity) - nr, err := RowNamespaceDataFromShares(extended, absentNs, 0) + nr, err := shwap.RowNamespaceDataFromShares(extended, absentNs, 0) require.NoError(t, err) require.Len(t, nr.Shares, 0) require.True(t, nr.Proof.IsOfAbsence()) } -func TestNamespacedSharesFromEDS(t *testing.T) { - const odsSize = 8 - sharesAmount := odsSize * odsSize - namespace := sharetest.RandV0Namespace() - for amount := 1; amount < sharesAmount; amount++ { - eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) - nd, err := NamespacedDataFromEDS(eds, namespace) - require.NoError(t, err) - require.True(t, len(nd) > 0) - require.Len(t, nd.Flatten(), amount) - - err = nd.Validate(root, namespace) - require.NoError(t, err) - } -} - func TestValidateNamespacedRow(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + const odsSize = 8 sharesAmount := odsSize * odsSize namespace := sharetest.RandV0Namespace() for amount := 1; amount < sharesAmount; amount++ { - eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) - nd, err := NamespacedDataFromEDS(eds, namespace) + randEDS, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) + rsmt2d := eds.Rsmt2D{ExtendedDataSquare: randEDS} + nd, err := eds.NamespacedData(ctx, root, rsmt2d, namespace) require.NoError(t, err) require.True(t, len(nd) > 0) @@ -89,15 +81,19 @@ func TestValidateNamespacedRow(t *testing.T) { } func TestNamespacedRowProtoEncoding(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + const odsSize = 8 namespace := sharetest.RandV0Namespace() - eds, _ := edstest.RandEDSWithNamespace(t, namespace, odsSize, odsSize) - nd, err := NamespacedDataFromEDS(eds, namespace) + randEDS, root := edstest.RandEDSWithNamespace(t, namespace, odsSize, odsSize) + rsmt2d := eds.Rsmt2D{ExtendedDataSquare: randEDS} + nd, err := eds.NamespacedData(ctx, root, rsmt2d, namespace) require.NoError(t, err) require.True(t, len(nd) > 0) expected := nd[0] pb := expected.ToProto() - ndOut := RowNamespaceDataFromProto(pb) + ndOut := shwap.RowNamespaceDataFromProto(pb) require.Equal(t, expected, ndOut) } diff --git a/share/shwap/row_test.go b/share/shwap/row_test.go index afe56c914c..5ea4c76b61 100644 --- a/share/shwap/row_test.go +++ b/share/shwap/row_test.go @@ -9,24 +9,23 @@ import ( "github.com/celestiaorg/celestia-node/share/eds/edstest" ) -func TestNewRowFromEDS(t *testing.T) { +func TestRowFromShares(t *testing.T) { const odsSize = 8 eds := edstest.RandEDS(t, odsSize) for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for _, side := range []RowSide{Left, Right} { - row := RowFromEDS(eds, rowIdx, side) - - want := eds.Row(uint(rowIdx)) - shares, err := row.Shares() + shares := eds.Row(uint(rowIdx)) + row := RowFromShares(shares, side) + extended, err := row.Shares() require.NoError(t, err) - require.Equal(t, want, shares) + require.Equal(t, shares, extended) var half []share.Share if side == Right { - half = want[odsSize:] + half = shares[odsSize:] } else { - half = want[:odsSize] + half = shares[:odsSize] } require.Equal(t, half, row.halfShares) require.Equal(t, side, row.side) @@ -42,7 +41,8 @@ func TestRowValidate(t *testing.T) { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for _, side := range []RowSide{Left, Right} { - row := RowFromEDS(eds, rowIdx, side) + shares := eds.Row(uint(rowIdx)) + row := RowFromShares(shares, side) err := row.Validate(root, rowIdx) require.NoError(t, err) @@ -56,7 +56,8 @@ func TestRowValidateNegativeCases(t *testing.T) { eds := edstest.RandEDS(t, 8) // Generate a random Extended Data Square of size 8 root, err := share.NewRoot(eds) require.NoError(t, err) - row := RowFromEDS(eds, 0, Left) + shares := eds.Row(0) + row := RowFromShares(shares, Left) // Test with incorrect side specification invalidSideRow := Row{halfShares: row.halfShares, side: RowSide(999)} @@ -89,7 +90,8 @@ func TestRowProtoEncoding(t *testing.T) { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for _, side := range []RowSide{Left, Right} { - row := RowFromEDS(eds, rowIdx, side) + shares := eds.Row(uint(rowIdx)) + row := RowFromShares(shares, side) pb := row.ToProto() rowOut := RowFromProto(pb) @@ -105,7 +107,8 @@ func BenchmarkRowValidate(b *testing.B) { eds := edstest.RandEDS(b, odsSize) root, err := share.NewRoot(eds) require.NoError(b, err) - row := RowFromEDS(eds, 0, Left) + shares := eds.Row(0) + row := RowFromShares(shares, Left) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/share/shwap/sample.go b/share/shwap/sample.go index ec1185eb1a..c521410e35 100644 --- a/share/shwap/sample.go +++ b/share/shwap/sample.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" - "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/nmt" nmt_pb "github.com/celestiaorg/nmt/pb" "github.com/celestiaorg/rsmt2d" @@ -13,6 +12,10 @@ import ( "github.com/celestiaorg/celestia-node/share/shwap/pb" ) +// ErrFailedVerification is returned when inclusion proof verification fails. It is returned +// when the data and the proof do not match trusted data root. +var ErrFailedVerification = errors.New("failed to verify inclusion") + // Sample represents a data share along with its Merkle proof, used to validate the share's // inclusion in a data square. type Sample struct { @@ -21,46 +24,6 @@ type Sample struct { ProofType rsmt2d.Axis // ProofType indicates whether the proof is against a row or a column. } -// SampleFromEDS samples a share from an Extended Data Square based on the provided index and axis. -// This function generates a Merkle tree proof for the specified share. -func SampleFromEDS( - square *rsmt2d.ExtendedDataSquare, - proofType rsmt2d.Axis, - rowIdx, colIdx int, -) (Sample, error) { - var shrs []share.Share - var axisIdx, shrIdx int - switch proofType { - case rsmt2d.Row: - axisIdx, shrIdx = rowIdx, colIdx - shrs = square.Row(uint(rowIdx)) - case rsmt2d.Col: - axisIdx, shrIdx = colIdx, rowIdx - shrs = square.Col(uint(colIdx)) - default: - return Sample{}, fmt.Errorf("invalid proof type: %d", proofType) - } - - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(square.Width()/2), uint(axisIdx)) - for _, shr := range shrs { - err := tree.Push(shr) - if err != nil { - return Sample{}, fmt.Errorf("while pushing shares to NMT: %w", err) - } - } - - prf, err := tree.ProveRange(shrIdx, shrIdx+1) - if err != nil { - return Sample{}, fmt.Errorf("while proving range share over NMT: %w", err) - } - - return Sample{ - Share: shrs[shrIdx], - Proof: &prf, - ProofType: proofType, - }, nil -} - // SampleFromProto converts a protobuf Sample back into its domain model equivalent. func SampleFromProto(s *pb.Sample) Sample { proof := nmt.NewInclusionProof( @@ -104,7 +67,7 @@ func (s Sample) Validate(dah *share.Root, rowIdx, colIdx int) error { return fmt.Errorf("invalid SampleProofType: %d", s.ProofType) } if !s.verifyInclusion(dah, rowIdx, colIdx) { - return fmt.Errorf("share proof is invalid") + return ErrFailedVerification } return nil } diff --git a/share/shwap/sample_test.go b/share/shwap/sample_test.go index 447ca80b6d..34fb4bfa9b 100644 --- a/share/shwap/sample_test.go +++ b/share/shwap/sample_test.go @@ -1,6 +1,7 @@ -package shwap +package shwap_test import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -9,42 +10,23 @@ import ( "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds/edstest" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" ) -func TestNewSampleFromEDS(t *testing.T) { - const odsSize = 8 - eds := edstest.RandEDS(t, odsSize) - - for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { - for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { - for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := SampleFromEDS(eds, proofType, rowIdx, colIdx) - require.NoError(t, err) - - want := eds.GetCell(uint(rowIdx), uint(colIdx)) - require.Equal(t, want, sample.Share) - require.Equal(t, proofType, sample.ProofType) - require.NotNil(t, sample.Proof) - require.Equal(t, sample.Proof.End()-sample.Proof.Start(), 1) - require.Len(t, sample.Proof.Nodes(), 4) - } - } - } -} - func TestSampleValidate(t *testing.T) { const odsSize = 8 - eds := edstest.RandEDS(t, odsSize) - root, err := share.NewRoot(eds) + randEDS := edstest.RandEDS(t, odsSize) + root, err := share.NewRoot(randEDS) require.NoError(t, err) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := SampleFromEDS(eds, proofType, rowIdx, colIdx) + sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) require.NoError(t, err) - require.True(t, sample.verifyInclusion(root, rowIdx, colIdx)) require.NoError(t, sample.Validate(root, rowIdx, colIdx)) } } @@ -54,55 +36,53 @@ func TestSampleValidate(t *testing.T) { // TestSampleNegativeVerifyInclusion checks func TestSampleNegativeVerifyInclusion(t *testing.T) { const odsSize = 8 - eds := edstest.RandEDS(t, odsSize) - root, err := share.NewRoot(eds) + randEDS := edstest.RandEDS(t, odsSize) + root, err := share.NewRoot(randEDS) require.NoError(t, err) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} - sample, err := SampleFromEDS(eds, rsmt2d.Row, 0, 0) + sample, err := inMem.Sample(context.Background(), 0, 0) + require.NoError(t, err) + err = sample.Validate(root, 0, 0) require.NoError(t, err) - included := sample.verifyInclusion(root, 0, 0) - require.True(t, included) // incorrect row index - included = sample.verifyInclusion(root, 1, 0) - require.False(t, included) - - // incorrect col index is not used in the inclusion proof verification - included = sample.verifyInclusion(root, 0, 1) - require.True(t, included) + err = sample.Validate(root, 1, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) // Corrupt the share sample.Share[0] ^= 0xFF - included = sample.verifyInclusion(root, 0, 0) - require.False(t, included) + err = sample.Validate(root, 0, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) // incorrect proofType - sample, err = SampleFromEDS(eds, rsmt2d.Row, 0, 0) + sample, err = inMem.Sample(context.Background(), 0, 0) require.NoError(t, err) sample.ProofType = rsmt2d.Col - included = sample.verifyInclusion(root, 0, 0) - require.False(t, included) + err = sample.Validate(root, 0, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) // Corrupt the last root hash byte - sample, err = SampleFromEDS(eds, rsmt2d.Row, 0, 0) + sample, err = inMem.Sample(context.Background(), 0, 0) require.NoError(t, err) root.RowRoots[0][len(root.RowRoots[0])-1] ^= 0xFF - included = sample.verifyInclusion(root, 0, 0) - require.False(t, included) + err = sample.Validate(root, 0, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) } func TestSampleProtoEncoding(t *testing.T) { const odsSize = 8 - eds := edstest.RandEDS(t, odsSize) + randEDS := edstest.RandEDS(t, odsSize) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := SampleFromEDS(eds, proofType, rowIdx, colIdx) + sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) require.NoError(t, err) pb := sample.ToProto() - sampleOut := SampleFromProto(pb) + sampleOut := shwap.SampleFromProto(pb) require.NoError(t, err) require.Equal(t, sample, sampleOut) } @@ -114,10 +94,11 @@ func TestSampleProtoEncoding(t *testing.T) { // BenchmarkSampleValidate-10 284829 3935 ns/op func BenchmarkSampleValidate(b *testing.B) { const odsSize = 32 - eds := edstest.RandEDS(b, odsSize) - root, err := share.NewRoot(eds) + randEDS := edstest.RandEDS(b, odsSize) + root, err := share.NewRoot(randEDS) require.NoError(b, err) - sample, err := SampleFromEDS(eds, rsmt2d.Row, 0, 0) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} + sample, err := inMem.SampleForProofAxis(0, 0, rsmt2d.Row) require.NoError(b, err) b.ResetTimer() diff --git a/share/store/file/eds_file.go b/share/store/file/eds_file.go deleted file mode 100644 index fadd1883a1..0000000000 --- a/share/store/file/eds_file.go +++ /dev/null @@ -1,25 +0,0 @@ -package file - -import ( - "context" - "io" - - "github.com/celestiaorg/rsmt2d" - - "github.com/celestiaorg/celestia-node/share" - "github.com/celestiaorg/celestia-node/share/shwap" -) - -type EdsFile interface { - io.Closer - // Size returns square size of the file. - Size() int - // Share returns share and corresponding proof for the given axis and share index in this axis. - Share(ctx context.Context, rowIdx, colIdx int) (*shwap.Sample, error) - // AxisHalf returns Shares for the first half of the axis of the given type and index. - AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) - // Data returns data for the given namespace and row index. - Data(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) - // EDS returns extended data square stored in the file. - EDS(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error) -} diff --git a/share/store/file/mem_file.go b/share/store/file/mem_file.go deleted file mode 100644 index d3fefcb29b..0000000000 --- a/share/store/file/mem_file.go +++ /dev/null @@ -1,114 +0,0 @@ -package file - -import ( - "context" - - "github.com/celestiaorg/celestia-app/pkg/wrapper" - "github.com/celestiaorg/nmt" - "github.com/celestiaorg/rsmt2d" - - "github.com/celestiaorg/celestia-node/share" - "github.com/celestiaorg/celestia-node/share/ipld" - "github.com/celestiaorg/celestia-node/share/shwap" -) - -var _ EdsFile = (*MemFile)(nil) - -type MemFile struct { - Eds *rsmt2d.ExtendedDataSquare -} - -func (f *MemFile) Close() error { - return nil -} - -func (f *MemFile) Size() int { - return int(f.Eds.Width()) -} - -func (f *MemFile) Share( - _ context.Context, - rowIdx, colIdx int, -) (*shwap.Sample, error) { - axisType := rsmt2d.Row - axisIdx, shrIdx := rowIdx, colIdx - - shares := getAxis(f.Eds, axisType, axisIdx) - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(f.Size()/2), uint(axisIdx)) - for _, shr := range shares { - err := tree.Push(shr) - if err != nil { - return nil, err - } - } - - proof, err := tree.ProveRange(shrIdx, shrIdx+1) - if err != nil { - return nil, err - } - - return &shwap.Sample{ - Share: shares[shrIdx], - Proof: &proof, - ProofType: axisType, - }, nil -} - -func (f *MemFile) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { - return AxisHalf{ - Shares: getAxis(f.Eds, axisType, axisIdx)[:f.Size()/2], - IsParity: false, - }, nil -} - -func (f *MemFile) Data(_ context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) { - shares := getAxis(f.Eds, rsmt2d.Row, rowIdx) - return ndDataFromShares(shares, namespace, rowIdx) -} - -func (f *MemFile) EDS(_ context.Context) (*rsmt2d.ExtendedDataSquare, error) { - return f.Eds, nil -} - -func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share { - switch axisType { - case rsmt2d.Row: - return eds.Row(uint(axisIdx)) - case rsmt2d.Col: - return eds.Col(uint(axisIdx)) - default: - panic("unknown axis") - } -} - -func ndDataFromShares(shares []share.Share, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) { - bserv := ipld.NewMemBlockservice() - batchAdder := ipld.NewNmtNodeAdder(context.TODO(), bserv, ipld.MaxSizeBatchOption(len(shares))) - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(rowIdx), - nmt.NodeVisitor(batchAdder.Visit)) - for _, shr := range shares { - err := tree.Push(shr) - if err != nil { - return shwap.RowNamespaceData{}, err - } - } - - root, err := tree.Root() - if err != nil { - return shwap.RowNamespaceData{}, err - } - - err = batchAdder.Commit() - if err != nil { - return shwap.RowNamespaceData{}, err - } - - row, proof, err := ipld.GetSharesByNamespace(context.TODO(), bserv, root, namespace, len(shares)) - if err != nil { - return shwap.RowNamespaceData{}, err - } - return shwap.RowNamespaceData{ - Shares: row, - Proof: proof, - }, nil -} diff --git a/share/store/file/mem_file_test.go b/share/store/file/mem_file_test.go deleted file mode 100644 index 0297dc63e6..0000000000 --- a/share/store/file/mem_file_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package file - -import ( - "context" - mrand "math/rand" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/celestiaorg/celestia-node/share" - "github.com/celestiaorg/celestia-node/share/eds/edstest" - "github.com/celestiaorg/celestia-node/share/sharetest" -) - -func TestMemFileShare(t *testing.T) { - eds := edstest.RandEDS(t, 32) - root, err := share.NewRoot(eds) - require.NoError(t, err) - fl := &MemFile{Eds: eds} - - width := int(eds.Width()) - for rowIdx := 0; rowIdx < width; rowIdx++ { - for colIdx := 0; colIdx < width; colIdx++ { - shr, err := fl.Share(context.TODO(), rowIdx, colIdx) - require.NoError(t, err) - - err = shr.Validate(root, rowIdx, colIdx) - require.NoError(t, err) - } - } -} - -func TestMemFileDate(t *testing.T) { - size := 32 - - // generate EDS with random data and some shares with the same namespace - namespace := sharetest.RandV0Namespace() - amount := mrand.Intn(size*size-1) + 1 - eds, dah := edstest.RandEDSWithNamespace(t, namespace, amount, size) - - file := &MemFile{Eds: eds} - - for i, root := range dah.RowRoots { - if !namespace.IsOutsideRange(root, root) { - nd, err := file.Data(context.Background(), namespace, i) - require.NoError(t, err) - err = nd.Validate(dah, namespace, i) - require.NoError(t, err) - } - } -}