Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(shwap): Add ODS file #3482

Merged
merged 13 commits into from
Jun 19, 2024
2 changes: 2 additions & 0 deletions share/shwap/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Sample struct {
ProofType rsmt2d.Axis // ProofType indicates whether the proof is against a row or a column.
}

// SampleFromShares creates a Sample from a list of shares, using the specified proof type and
// the share index to be included in the sample.
func SampleFromShares(shares []share.Share, proofType rsmt2d.Axis, axisIdx, shrIdx int) (Sample, error) {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(axisIdx))
for _, shr := range shares {
Expand Down
38 changes: 20 additions & 18 deletions store/file/ods.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,15 @@ func (f *ODSFile) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample,

// AxisHalf returns half of shares axis of the given type and index. Side is determined by
// implementation. Implementations should indicate the side in the returned AxisHalf.
func (f *ODSFile) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
func (f *ODSFile) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
// Read the axis from the file if the axis is a row and from the top half of the square, or if the
// axis is a column and from the left half of the square.
if axisIdx < f.size()/2 {
shares, err := f.readAxisHalf(axisType, axisIdx)
half, err := f.readAxisHalf(axisType, axisIdx)
if err != nil {
return eds.AxisHalf{}, fmt.Errorf("reading axis half: %w", err)
}
return eds.AxisHalf{
Shares: shares,
IsParity: false,
}, nil
return half, nil
}

// if axis is from the second half of the square, read full ODS and compute the axis half
Expand All @@ -161,14 +158,11 @@ func (f *ODSFile) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx in
return eds.AxisHalf{}, err
}

shares, err := f.ods.computeAxisHalf(ctx, axisType, axisIdx)
half, err := f.ods.computeAxisHalf(axisType, axisIdx)
if err != nil {
return eds.AxisHalf{}, fmt.Errorf("computing axis half: %w", err)
}
return eds.AxisHalf{
Shares: shares,
IsParity: false,
}, nil
return half, nil
}

// RowNamespaceData returns data for the given namespace and row index.
Expand All @@ -193,21 +187,29 @@ func (f *ODSFile) Shares(context.Context) ([]share.Share, error) {
return f.ods.shares()
}

func (f *ODSFile) readAxisHalf(axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) {
func (f *ODSFile) readAxisHalf(axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
f.lock.RLock()
ODS := f.ods
f.lock.RUnlock()
if ODS != nil {
return f.ods.axisHalf(context.Background(), axisType, axisIdx)
return f.ods.axisHalf(axisType, axisIdx)
}

switch axisType {
case rsmt2d.Col:
return f.readCol(axisIdx, 0)
col, err := f.readCol(axisIdx, 0)
return eds.AxisHalf{
Shares: col,
IsParity: false,
}, err
case rsmt2d.Row:
return f.readRow(axisIdx)
row, err := f.readRow(axisIdx)
return eds.AxisHalf{
Shares: row,
IsParity: false,
}, err
}
return nil, fmt.Errorf("unknown axis")
return eds.AxisHalf{}, fmt.Errorf("unknown axis")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uber optional nit: dedup

}

func (f *ODSFile) readODS() error {
Expand All @@ -233,7 +235,7 @@ func (f *ODSFile) readODS() error {

func (f *ODSFile) readRow(idx int) ([]share.Share, error) {
shrLn := int(f.hdr.shareSize)
ODSLn := int(f.size()) / 2
ODSLn := f.size() / 2

shares := make([]share.Share, ODSLn)

Expand All @@ -253,7 +255,7 @@ func (f *ODSFile) readRow(idx int) ([]share.Share, error) {

func (f *ODSFile) readCol(axisIdx, quadrantIdx int) ([]share.Share, error) {
shrLn := int(f.hdr.shareSize)
ODSLn := int(f.size()) / 2
ODSLn := f.size() / 2
quadrantOffset := quadrantIdx * ODSLn * ODSLn * shrLn

shares := make([]share.Share, ODSLn)
Expand Down
56 changes: 35 additions & 21 deletions store/file/square.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package file

import (
"bufio"
"context"
"fmt"
"io"

Expand All @@ -11,6 +10,7 @@ import (
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
eds "github.com/celestiaorg/celestia-node/share/new_eds"
)

type square [][]share.Share
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this PR used only for ODS? Should we call it so, then? Or will it be used for EDS as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or at least the readSquare func because its hardcoded to read the ODS, but square mean both ODS and EDS. Need to clarify

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sync: call it a quadrant

Expand Down Expand Up @@ -51,50 +51,57 @@ func (s square) size() int {
return len(s)
}

func (s square) axisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) {
func (s square) shares() ([]share.Share, error) {
shares := make([]share.Share, 0, s.size()*s.size())
for _, row := range s {
shares = append(shares, row...)
}
return shares, nil
}

func (s square) axisHalf(axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
if s == nil {
return nil, fmt.Errorf("square is nil")
return eds.AxisHalf{}, fmt.Errorf("square is nil")
}

if axisIdx >= s.size() {
return nil, fmt.Errorf("index is out of square bounds")
return eds.AxisHalf{}, fmt.Errorf("index is out of square bounds")
}

// square stores rows directly in high level slice, so we can return by accessing row by index
if axisType == rsmt2d.Row {
return s[axisIdx], nil
row := s[axisIdx]
return eds.AxisHalf{
Shares: row,
IsParity: false,
}, nil
}

// construct half column from row ordered square
col := make([]share.Share, s.size())
for i := 0; i < s.size(); i++ {
col[i] = s[i][axisIdx]
}
return col, nil
return eds.AxisHalf{
Shares: col,
IsParity: false,
}, nil
}

func (s square) shares() ([]share.Share, error) {
shares := make([]share.Share, 0, s.size()*s.size())
for _, row := range s {
shares = append(shares, row...)
}
return shares, nil
}

// TODO(@walldiss): make comments with diagram of computed axis. Add more comment on actual algo
// TODO(@walldiss): Add more comment on actual algo and support it with visual diagram of computed
// axis.
func (s square) computeAxisHalf(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this a full axis?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, compute the opposite half, basically parity, because the square stores the ODS. This needs to be clarified in the comment or by a better function name, I believe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sync: comment with a picture

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and generally more comments

ctx context.Context,
axisType rsmt2d.Axis,
axisIdx int,
) ([]share.Share, error) {
) (eds.AxisHalf, error) {
shares := make([]share.Share, s.size())

// extend opposite half of the square while collecting Shares for the first half of required axis
g, ctx := errgroup.WithContext(ctx)
g := errgroup.Group{}
opposite := oppositeAxis(axisType)
for i := 0; i < s.size(); i++ {
g.Go(func() error {
original, err := s.axisHalf(ctx, opposite, i)
half, err := s.axisHalf(opposite, i)
if err != nil {
return err
}
Expand All @@ -105,7 +112,11 @@ func (s square) computeAxisHalf(
}

shards := make([][]byte, s.size()*2)
copy(shards, original)
if half.IsParity {
copy(shards[s.size():], half.Shares)
} else {
copy(shards, half.Shares)
}

target := make([]bool, s.size()*2)
target[axisIdx] = true
Expand All @@ -121,7 +132,10 @@ func (s square) computeAxisHalf(
}

err := g.Wait()
return shares, err
return eds.AxisHalf{
Shares: shares,
IsParity: false,
}, err
}

func oppositeAxis(axis rsmt2d.Axis) rsmt2d.Axis {
Expand Down
Loading