diff --git a/Makefile b/Makefile index 000d01fa02..31197dc824 100644 --- a/Makefile +++ b/Makefile @@ -161,13 +161,14 @@ PB_CORE=$(shell go list -f {{.Dir}} -m github.com/tendermint/tendermint) PB_GOGO=$(shell go list -f {{.Dir}} -m github.com/gogo/protobuf) PB_CELESTIA_APP=$(shell go list -f {{.Dir}} -m github.com/celestiaorg/celestia-app) PB_NMT=$(shell go list -f {{.Dir}} -m github.com/celestiaorg/nmt) +PB_NODE=$(shell pwd) ## pb-gen: Generate protobuf code for all /pb/*.proto files in the project. pb-gen: @echo '--> Generating protobuf' @for dir in $(PB_PKGS); \ do for file in `find $$dir -type f -name "*.proto"`; \ - do protoc -I=. -I=${PB_CORE}/proto/ -I=${PB_GOGO} -I=${PB_CELESTIA_APP}/proto -I=${PB_NMT} --gogofaster_out=paths=source_relative:. $$file; \ + do protoc -I=. -I=${PB_CORE}/proto/ -I=${PB_NODE} -I=${PB_GOGO} -I=${PB_CELESTIA_APP}/proto -I=${PB_NMT} --gogofaster_out=paths=source_relative:. $$file; \ echo '-->' $$file; \ done; \ done; diff --git a/go.mod b/go.mod index d852a48c61..a3b2525e82 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/ipfs/go-ipld-format v0.6.0 github.com/ipfs/go-log/v2 v2.5.1 github.com/ipld/go-car v0.6.2 + github.com/klauspost/reedsolomon v1.12.1 github.com/libp2p/go-libp2p v0.33.2 github.com/libp2p/go-libp2p-kad-dht v0.25.2 github.com/libp2p/go-libp2p-pubsub v0.10.1 @@ -228,7 +229,6 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.17.6 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/klauspost/reedsolomon v1.12.1 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/lib/pq v1.10.7 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/share/availability.go b/share/availability.go index f3511da450..3373a62276 100644 --- a/share/availability.go +++ b/share/availability.go @@ -4,29 +4,12 @@ import ( "context" "errors" - "github.com/celestiaorg/celestia-app/pkg/da" - "github.com/celestiaorg/rsmt2d" - "github.com/celestiaorg/celestia-node/header" ) // ErrNotAvailable is returned whenever DA sampling fails. var ErrNotAvailable = errors.New("share: data not available") -// Root represents root commitment to multiple Shares. -// In practice, it is a commitment to all the Data in a square. -type Root = da.DataAvailabilityHeader - -// NewRoot generates Root(DataAvailabilityHeader) using the -// provided extended data square. -func NewRoot(eds *rsmt2d.ExtendedDataSquare) (*Root, error) { - dah, err := da.NewDataAvailabilityHeader(eds) - if err != nil { - return nil, err - } - return &dah, nil -} - // Availability defines interface for validation of Shares' availability. // //go:generate mockgen -destination=availability/mocks/availability.go -package=mocks . Availability diff --git a/share/eds/byzantine/share_proof.go b/share/eds/byzantine/share_proof.go index dbc687e54b..d064656830 100644 --- a/share/eds/byzantine/share_proof.go +++ b/share/eds/byzantine/share_proof.go @@ -72,8 +72,8 @@ func (s *ShareWithProof) ShareWithProofToProto() *pb.Share { } } -// GetShareWithProof attempts to get a share with proof for the given share. It first tries to get a row proof -// and if that fails or proof is invalid, it tries to get a column proof. +// GetShareWithProof attempts to get a share with proof for the given share. It first tries to get +// a row proof and if that fails or proof is invalid, it tries to get a column proof. 
func GetShareWithProof( ctx context.Context, bGetter blockservice.BlockGetter, diff --git a/share/eds/edstest/testing.go b/share/eds/edstest/testing.go index bf5e664f90..450f706a39 100644 --- a/share/eds/edstest/testing.go +++ b/share/eds/edstest/testing.go @@ -34,12 +34,14 @@ func RandEDS(t require.TestingT, size int) *rsmt2d.ExtendedDataSquare { return eds } +// RandEDSWithNamespace generates EDS with given square size. Returned EDS will have +// namespacedAmount of shares with the given namespace. func RandEDSWithNamespace( t require.TestingT, namespace share.Namespace, - size int, + namespacedAmount, size int, ) (*rsmt2d.ExtendedDataSquare, *share.Root) { - shares := sharetest.RandSharesWithNamespace(t, namespace, size*size) + shares := sharetest.RandSharesWithNamespace(t, namespace, namespacedAmount, size*size) eds, err := rsmt2d.ComputeExtendedDataSquare(shares, share.DefaultRSMT2DCodec(), wrapper.NewConstructor(uint64(size))) require.NoError(t, err, "failure to recompute the extended data square") dah, err := share.NewRoot(eds) diff --git a/share/eds/utils.go b/share/eds/utils.go index b897dd14b5..c38b74e349 100644 --- a/share/eds/utils.go +++ b/share/eds/utils.go @@ -121,24 +121,24 @@ func CollectSharesByNamespace( utils.SetStatusAndEnd(span, err) }() - rootCIDs := ipld.FilterRootByNamespace(root, namespace) - if len(rootCIDs) == 0 { + rowIdxs := share.RowsWithNamespace(root, namespace) + if len(rowIdxs) == 0 { return []share.NamespacedRow{}, nil } errGroup, ctx := errgroup.WithContext(ctx) - shares = make([]share.NamespacedRow, len(rootCIDs)) - for i, rootCID := range rootCIDs { + shares = make([]share.NamespacedRow, len(rowIdxs)) + for i, rowIdx := range rowIdxs { // shadow loop variables, to ensure correct values are captured - i, rootCID := i, rootCID + rowIdx, rowRoot := rowIdx, root.RowRoots[rowIdx] errGroup.Go(func() error { - row, proof, err := ipld.GetSharesByNamespace(ctx, bg, rootCID, namespace, len(root.RowRoots)) + row, proof, err := ipld.GetSharesByNamespace(ctx, bg, rowRoot, namespace, len(root.RowRoots)) shares[i] = share.NamespacedRow{ Shares: row, Proof: proof, } if err != nil { - return fmt.Errorf("retrieving shares by namespace %s for row %x: %w", namespace.String(), rootCID, err) + return fmt.Errorf("retrieving shares by namespace %s for row %d: %w", namespace.String(), rowIdx, err) } return nil }) diff --git a/share/getter.go b/share/getter.go index 363225f971..0ad6f3f9f4 100644 --- a/share/getter.go +++ b/share/getter.go @@ -71,16 +71,19 @@ func (ns NamespacedShares) Verify(root *Root, namespace Namespace) error { } for i, row := range ns { + if row.Proof == nil && row.Shares == nil { + return fmt.Errorf("row verification failed: no proofs and shares") + } // verify row data against row hash from original root - if !row.verify(originalRoots[i], namespace) { + if !row.Verify(originalRoots[i], namespace) { return fmt.Errorf("row verification failed: row %d doesn't match original root: %s", i, root.String()) } } return nil } -// verify validates the row using nmt inclusion proof. -func (row *NamespacedRow) verify(rowRoot []byte, namespace Namespace) bool { +// Verify validates the row using nmt inclusion proof. 
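+//
+// Illustrative usage (a sketch; assumes rowRoot is the root of the row the shares
+// and proof were collected from):
+//
+//	if !row.Verify(rowRoot, namespace) {
+//		// inclusion proof does not match the row root
+//	}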
+func (row *NamespacedRow) Verify(rowRoot []byte, namespace Namespace) bool { // construct nmt leaves from shares by prepending namespace leaves := make([][]byte, 0, len(row.Shares)) for _, shr := range row.Shares { diff --git a/share/getters/shrex_test.go b/share/getters/shrex_test.go index 74896e6c15..10a805b13b 100644 --- a/share/getters/shrex_test.go +++ b/share/getters/shrex_test.go @@ -61,8 +61,9 @@ func TestShrexGetter(t *testing.T) { t.Cleanup(cancel) // generate test data + size := 64 namespace := sharetest.RandV0Namespace() - randEDS, dah := edstest.RandEDSWithNamespace(t, namespace, 64) + randEDS, dah := edstest.RandEDSWithNamespace(t, namespace, size*size, size) eh := headertest.RandExtendedHeaderWithRoot(t, dah) require.NoError(t, edsStore.Put(ctx, dah.Hash(), randEDS)) peerManager.Validate(ctx, srvHost.ID(), shrexsub.Notification{ diff --git a/share/ipld/get_shares.go b/share/ipld/get_shares.go index 98db7012b5..2883e62761 100644 --- a/share/ipld/get_shares.go +++ b/share/ipld/get_shares.go @@ -44,12 +44,13 @@ func GetShares(ctx context.Context, bg blockservice.BlockGetter, root cid.Cid, s func GetSharesByNamespace( ctx context.Context, bGetter blockservice.BlockGetter, - root cid.Cid, + root []byte, namespace share.Namespace, maxShares int, ) ([]share.Share, *nmt.Proof, error) { + rootCid := MustCidFromNamespacedSha256(root) data := NewNamespaceData(maxShares, namespace, WithLeaves(), WithProofs()) - err := data.CollectLeavesByNamespace(ctx, bGetter, root) + err := data.CollectLeavesByNamespace(ctx, bGetter, rootCid) if err != nil { return nil, nil, err } diff --git a/share/ipld/get_shares_test.go b/share/ipld/get_shares_test.go index 2f5d630473..c5b606acee 100644 --- a/share/ipld/get_shares_test.go +++ b/share/ipld/get_shares_test.go @@ -174,8 +174,7 @@ func TestGetSharesByNamespace(t *testing.T) { rowRoots, err := eds.RowRoots() require.NoError(t, err) for _, row := range rowRoots { - rcid := MustCidFromNamespacedSha256(row) - rowShares, _, err := GetSharesByNamespace(ctx, bServ, rcid, namespace, len(rowRoots)) + rowShares, _, err := GetSharesByNamespace(ctx, bServ, row, namespace, len(rowRoots)) if errors.Is(err, ErrNamespaceOutsideRange) { continue } @@ -363,8 +362,7 @@ func TestGetSharesWithProofsByNamespace(t *testing.T) { rowRoots, err := eds.RowRoots() require.NoError(t, err) for _, row := range rowRoots { - rcid := MustCidFromNamespacedSha256(row) - rowShares, proof, err := GetSharesByNamespace(ctx, bServ, rcid, namespace, len(rowRoots)) + rowShares, proof, err := GetSharesByNamespace(ctx, bServ, row, namespace, len(rowRoots)) if namespace.IsOutsideRange(row, row) { require.ErrorIs(t, err, ErrNamespaceOutsideRange) continue @@ -386,7 +384,7 @@ func TestGetSharesWithProofsByNamespace(t *testing.T) { share.NewSHA256Hasher(), namespace.ToNMT(), leaves, - NamespacedSha256FromCID(rcid)) + row) require.True(t, verified) // verify inclusion @@ -394,7 +392,7 @@ func TestGetSharesWithProofsByNamespace(t *testing.T) { share.NewSHA256Hasher(), namespace.ToNMT(), rowShares, - NamespacedSha256FromCID(rcid)) + row) require.True(t, verified) } } diff --git a/share/namespace.go b/share/namespace.go index df4ad74058..2cea574bbc 100644 --- a/share/namespace.go +++ b/share/namespace.go @@ -2,7 +2,9 @@ package share import ( "bytes" + "encoding/binary" "encoding/hex" + "errors" "fmt" appns "github.com/celestiaorg/celestia-app/pkg/namespace" @@ -182,3 +184,49 @@ func (n Namespace) IsGreater(target Namespace) bool { func (n Namespace) IsGreaterOrEqualThan(target Namespace) bool { return 
bytes.Compare(n, target) > -1 } + +// AddInt adds arbitrary int value to namespace, treating namespace as big-endian +// implementation of int +func (n Namespace) AddInt(val int) (Namespace, error) { + if val == 0 { + return n, nil + } + // Convert the input integer to a byte slice and add it to result slice + result := make([]byte, len(n)) + if val > 0 { + binary.BigEndian.PutUint64(result[len(n)-8:], uint64(val)) + } else { + binary.BigEndian.PutUint64(result[len(n)-8:], uint64(-val)) + } + + // Perform addition byte by byte + var carry int + for i := len(n) - 1; i >= 0; i-- { + var sum int + if val > 0 { + sum = int(n[i]) + int(result[i]) + carry + } else { + sum = int(n[i]) - int(result[i]) + carry + } + + switch { + case sum > 255: + carry = 1 + sum -= 256 + case sum < 0: + carry = -1 + sum += 256 + default: + carry = 0 + } + + result[i] = uint8(sum) + } + + // Handle any remaining carry + if carry != 0 { + return nil, errors.New("namespace overflow") + } + + return result, nil +} diff --git a/share/new_eds/accessor.go b/share/new_eds/accessor.go new file mode 100644 index 0000000000..09920161f6 --- /dev/null +++ b/share/new_eds/accessor.go @@ -0,0 +1,34 @@ +package eds + +import ( + "context" + "io" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +// Accessor is an interface for accessing extended data square data. +type Accessor interface { + // Size returns square size of the Accessor. + Size(ctx context.Context) int + // Sample returns share and corresponding proof for row and column indices. Implementation can + // choose which axis to use for proof. Chosen axis for proof should be indicated in the returned + // Sample. + Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) + // AxisHalf returns half of shares axis of the given type and index. Side is determined by + // implementation. Implementations should indicate the side in the returned AxisHalf. + AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) + // RowNamespaceData returns data for the given namespace and row index. + RowNamespaceData(ctx context.Context, namespace share.Namespace, rowIdx int) (shwap.RowNamespaceData, error) + // Shares returns data (ODS) shares extracted from the Accessor. + Shares(ctx context.Context) ([]share.Share, error) +} + +// AccessorCloser is an interface that groups Accessor and io.Closer interfaces. +type AccessorCloser interface { + Accessor + io.Closer +} diff --git a/share/new_eds/axis_half.go b/share/new_eds/axis_half.go new file mode 100644 index 0000000000..dede70ebbc --- /dev/null +++ b/share/new_eds/axis_half.go @@ -0,0 +1,69 @@ +package eds + +import ( + "fmt" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +// AxisHalf represents a half of data for a row or column in the EDS. +type AxisHalf struct { + Shares []share.Share + // IsParity indicates whether the half is parity or data. + IsParity bool +} + +// ToRow converts the AxisHalf to a shwap.Row. +func (a AxisHalf) ToRow() shwap.Row { + side := shwap.Left + if a.IsParity { + side = shwap.Right + } + return shwap.NewRow(a.Shares, side) +} + +// Extended returns full axis shares from half axis shares. +func (a AxisHalf) Extended() ([]share.Share, error) { + if a.IsParity { + return reconstructShares(a.Shares) + } + return extendShares(a.Shares) +} + +// extendShares constructs full axis shares from original half axis shares. 
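+//
+// For an original half of length k, the result has length 2k: the first k entries are
+// the original shares and the last k are the Reed-Solomon parity shares produced by the
+// default codec.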
+func extendShares(original []share.Share) ([]share.Share, error) { + if len(original) == 0 { + return nil, fmt.Errorf("original shares are empty") + } + + codec := share.DefaultRSMT2DCodec() + parity, err := codec.Encode(original) + if err != nil { + return nil, fmt.Errorf("encoding: %w", err) + } + shares := make([]share.Share, len(original)*2) + copy(shares, original) + copy(shares[len(original):], parity) + return shares, nil +} + +// reconstructShares constructs full axis shares from parity half axis shares. +func reconstructShares(parity []share.Share) ([]share.Share, error) { + if len(parity) == 0 { + return nil, fmt.Errorf("parity shares are empty") + } + + sqLen := len(parity) * 2 + shares := make([]share.Share, sqLen) + for i := sqLen / 2; i < sqLen; i++ { + shares[i] = parity[i-sqLen/2] + } + + codec := share.DefaultRSMT2DCodec() + shares, err := codec.Decode(shares) + if err != nil { + return nil, fmt.Errorf("reconstructing: %w", err) + } + return shares, nil +} diff --git a/share/new_eds/axis_half_test.go b/share/new_eds/axis_half_test.go new file mode 100644 index 0000000000..752add5acd --- /dev/null +++ b/share/new_eds/axis_half_test.go @@ -0,0 +1,32 @@ +package eds + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestExtendAxisHalf(t *testing.T) { + shares := sharetest.RandShares(t, 16) + + original := AxisHalf{ + Shares: shares, + IsParity: false, + } + + extended, err := original.Extended() + require.NoError(t, err) + require.Len(t, extended, len(shares)*2) + + parity := AxisHalf{ + Shares: extended[len(shares):], + IsParity: true, + } + + parityExtended, err := parity.Extended() + require.NoError(t, err) + + require.Equal(t, extended, parityExtended) +} diff --git a/share/new_eds/nd.go b/share/new_eds/nd.go new file mode 100644 index 0000000000..6031acc41e --- /dev/null +++ b/share/new_eds/nd.go @@ -0,0 +1,31 @@ +package eds + +import ( + "context" + "fmt" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +// NamespacedData extracts shares for a specific namespace from an EDS, considering +// each row independently. It uses root to determine which rows to extract data from, +// avoiding the need to recalculate the row roots for each row. 
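+//
+// Illustrative usage (a sketch; assumes root and the Accessor describe the same EDS):
+//
+//	nd, err := NamespacedData(ctx, root, accessor, namespace)
+//	if err != nil {
+//		return nil, err
+//	}
+//	if err := nd.Validate(root, namespace); err != nil {
+//		return nil, err
+//	}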
+func NamespacedData( + ctx context.Context, + root *share.Root, + eds Accessor, + namespace share.Namespace, +) (shwap.NamespacedData, error) { + rowIdxs := share.RowsWithNamespace(root, namespace) + rows := make(shwap.NamespacedData, len(rowIdxs)) + var err error + for i, idx := range rowIdxs { + rows[i], err = eds.RowNamespaceData(ctx, namespace, idx) + if err != nil { + return nil, fmt.Errorf("failed to process row %d: %w", idx, err) + } + } + + return rows, nil +} diff --git a/share/new_eds/nd_test.go b/share/new_eds/nd_test.go new file mode 100644 index 0000000000..a0780292ef --- /dev/null +++ b/share/new_eds/nd_test.go @@ -0,0 +1,32 @@ +package eds + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestNamespacedData(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + const odsSize = 8 + sharesAmount := odsSize * odsSize + namespace := sharetest.RandV0Namespace() + for amount := 1; amount < sharesAmount; amount++ { + eds, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) + rsmt2d := Rsmt2D{ExtendedDataSquare: eds} + nd, err := NamespacedData(ctx, root, rsmt2d, namespace) + require.NoError(t, err) + require.True(t, len(nd) > 0) + require.Len(t, nd.Flatten(), amount) + + err = nd.Validate(root, namespace) + require.NoError(t, err) + } +} diff --git a/share/new_eds/rsmt2d.go b/share/new_eds/rsmt2d.go new file mode 100644 index 0000000000..e5ffe6704b --- /dev/null +++ b/share/new_eds/rsmt2d.go @@ -0,0 +1,116 @@ +package eds + +import ( + "context" + "fmt" + + "github.com/celestiaorg/celestia-app/pkg/wrapper" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ Accessor = Rsmt2D{} + +// Rsmt2D is a rsmt2d based in-memory implementation of Accessor. +type Rsmt2D struct { + *rsmt2d.ExtendedDataSquare +} + +// Size returns the size of the Extended Data Square. +func (eds Rsmt2D) Size(context.Context) int { + return int(eds.Width()) +} + +// Sample returns share and corresponding proof for row and column indices. +func (eds Rsmt2D) Sample( + _ context.Context, + rowIdx, colIdx int, +) (shwap.Sample, error) { + return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row) +} + +// SampleForProofAxis samples a share from an Extended Data Square based on the provided +// row and column indices and proof axis. It returns a sample with the share and proof. +func (eds Rsmt2D) SampleForProofAxis( + rowIdx, colIdx int, + proofType rsmt2d.Axis, +) (shwap.Sample, error) { + axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType) + shares := getAxis(eds.ExtendedDataSquare, proofType, axisIdx) + + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(eds.Width()/2), uint(axisIdx)) + for _, shr := range shares { + err := tree.Push(shr) + if err != nil { + return shwap.Sample{}, fmt.Errorf("while pushing shares to NMT: %w", err) + } + } + + prf, err := tree.ProveRange(shrIdx, shrIdx+1) + if err != nil { + return shwap.Sample{}, fmt.Errorf("while proving range share over NMT: %w", err) + } + + return shwap.Sample{ + Share: shares[shrIdx], + Proof: &prf, + ProofType: proofType, + }, nil +} + +// AxisHalf returns Shares for the first half of the axis of the given type and index. 
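+// For the in-memory EDS this is always the original (non-parity) half, so IsParity is
+// always false in the returned AxisHalf.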
+func (eds Rsmt2D) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { + shares := getAxis(eds.ExtendedDataSquare, axisType, axisIdx) + halfShares := shares[:eds.Width()/2] + return AxisHalf{ + Shares: halfShares, + IsParity: false, + }, nil +} + +// HalfRow constructs a new shwap.Row from an Extended Data Square based on the specified index and +// side. +func (eds Rsmt2D) HalfRow(idx int, side shwap.RowSide) shwap.Row { + shares := eds.ExtendedDataSquare.Row(uint(idx)) + return shwap.RowFromShares(shares, side) +} + +// RowNamespaceData returns data for the given namespace and row index. +func (eds Rsmt2D) RowNamespaceData( + _ context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + shares := eds.Row(uint(rowIdx)) + return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx) +} + +// Shares returns data (ODS) shares extracted from the EDS. It returns new copy of the shares each +// time. +func (eds Rsmt2D) Shares(_ context.Context) ([]share.Share, error) { + return eds.ExtendedDataSquare.FlattenedODS(), nil +} + +func getAxis(eds *rsmt2d.ExtendedDataSquare, axisType rsmt2d.Axis, axisIdx int) []share.Share { + switch axisType { + case rsmt2d.Row: + return eds.Row(uint(axisIdx)) + case rsmt2d.Col: + return eds.Col(uint(axisIdx)) + default: + panic("unknown axis") + } +} + +func relativeIndexes(rowIdx, colIdx int, axisType rsmt2d.Axis) (axisIdx, shrIdx int) { + switch axisType { + case rsmt2d.Row: + return rowIdx, colIdx + case rsmt2d.Col: + return colIdx, rowIdx + default: + panic(fmt.Sprintf("invalid proof type: %d", axisType)) + } +} diff --git a/share/new_eds/rsmt2d_test.go b/share/new_eds/rsmt2d_test.go new file mode 100644 index 0000000000..eafcb607ae --- /dev/null +++ b/share/new_eds/rsmt2d_test.go @@ -0,0 +1,71 @@ +package eds + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +func TestMemFile(t *testing.T) { + odsSize := 8 + newAccessor := func(tb testing.TB, eds *rsmt2d.ExtendedDataSquare) Accessor { + return &Rsmt2D{ExtendedDataSquare: eds} + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + TestSuiteAccessor(ctx, t, newAccessor, odsSize) +} + +func TestRsmt2dHalfRow(t *testing.T) { + const odsSize = 8 + eds, _ := randRsmt2dAccsessor(t, odsSize) + + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for _, side := range []shwap.RowSide{shwap.Left, shwap.Right} { + row := eds.HalfRow(rowIdx, side) + + want := eds.Row(uint(rowIdx)) + shares, err := row.Shares() + require.NoError(t, err) + require.Equal(t, want, shares) + } + } +} + +func TestRsmt2dSampleForProofAxis(t *testing.T) { + const odsSize = 8 + eds := edstest.RandEDS(t, odsSize) + accessor := Rsmt2D{ExtendedDataSquare: eds} + + for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for colIdx := 0; colIdx < odsSize*2; colIdx++ { + sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType) + require.NoError(t, err) + + want := eds.GetCell(uint(rowIdx), uint(colIdx)) + require.Equal(t, want, sample.Share) + require.Equal(t, proofType, sample.ProofType) + require.NotNil(t, sample.Proof) + require.Equal(t, sample.Proof.End()-sample.Proof.Start(), 1) + require.Len(t, sample.Proof.Nodes(), 4) + 
} + } + } +} + +func randRsmt2dAccsessor(t *testing.T, size int) (Rsmt2D, *share.Root) { + eds := edstest.RandEDS(t, size) + root, err := share.NewRoot(eds) + require.NoError(t, err) + return Rsmt2D{ExtendedDataSquare: eds}, root +} diff --git a/share/new_eds/testing.go b/share/new_eds/testing.go new file mode 100644 index 0000000000..fc9564dc7d --- /dev/null +++ b/share/new_eds/testing.go @@ -0,0 +1,300 @@ +package eds + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/nmt" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/sharetest" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +type createAccessor func(testing.TB, *rsmt2d.ExtendedDataSquare) Accessor + +// TestSuiteAccessor runs a suite of tests for the given Accessor implementation. +func TestSuiteAccessor( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + t.Run("Sample", func(t *testing.T) { + testAccessorSample(ctx, t, createAccessor, odsSize) + }) + + t.Run("AxisHalf", func(t *testing.T) { + testAccessorAxisHalf(ctx, t, createAccessor, odsSize) + }) + + t.Run("RowNamespaceData", func(t *testing.T) { + testAccessorRowNamespaceData(ctx, t, createAccessor, odsSize) + }) + + t.Run("Shares", func(t *testing.T) { + testAccessorShares(ctx, t, createAccessor, odsSize) + }) +} + +func testAccessorSample( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + fl := createAccessor(t, eds) + + dah, err := share.NewRoot(eds) + require.NoError(t, err) + + width := int(eds.Width()) + t.Run("single thread", func(t *testing.T) { + for rowIdx := 0; rowIdx < width; rowIdx++ { + for colIdx := 0; colIdx < width; colIdx++ { + testSample(ctx, t, fl, dah, colIdx, rowIdx) + } + } + }) + + t.Run("parallel", func(t *testing.T) { + wg := sync.WaitGroup{} + for rowIdx := 0; rowIdx < width; rowIdx++ { + for colIdx := 0; colIdx < width; colIdx++ { + wg.Add(1) + go func(rowIdx, colIdx int) { + defer wg.Done() + testSample(ctx, t, fl, dah, rowIdx, colIdx) + }(rowIdx, colIdx) + } + } + wg.Wait() + }) +} + +func testSample( + ctx context.Context, + t *testing.T, + fl Accessor, + dah *share.Root, + rowIdx, colIdx int, +) { + shr, err := fl.Sample(ctx, rowIdx, colIdx) + require.NoError(t, err) + + err = shr.Validate(dah, rowIdx, colIdx) + require.NoError(t, err) +} + +func testAccessorRowNamespaceData( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + t.Run("included", func(t *testing.T) { + // generate EDS with random data and some Shares with the same namespace + sharesAmount := odsSize * odsSize + namespace := sharetest.RandV0Namespace() + // test with different amount of shares + for amount := 1; amount < sharesAmount; amount++ { + // select random amount of shares, but not less than 1 + eds, dah := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) + f := createAccessor(t, eds) + + var actualSharesAmount int + // loop over all rows and check that the amount of shares in the namespace is equal to the expected + // amount + for i, root := range dah.RowRoots { + rowData, err := f.RowNamespaceData(ctx, namespace, i) + + // namespace is not included in the row, so there should be no shares + if namespace.IsOutsideRange(root, root) { + require.ErrorIs(t, err, 
shwap.ErrNamespaceOutsideRange) + require.Len(t, rowData.Shares, 0) + continue + } + + actualSharesAmount += len(rowData.Shares) + require.NoError(t, err) + require.True(t, len(rowData.Shares) > 0) + err = rowData.Validate(dah, namespace, i) + require.NoError(t, err) + } + + // check that the amount of shares in the namespace is equal to the expected amount + require.Equal(t, amount, actualSharesAmount) + } + }) + + t.Run("not included", func(t *testing.T) { + // generate EDS with random data and some Shares with the same namespace + eds := edstest.RandEDS(t, odsSize) + dah, err := share.NewRoot(eds) + require.NoError(t, err) + + // loop over first half of the rows, because the second half is parity and does not contain + // namespaced shares + for i, root := range dah.RowRoots[:odsSize] { + // select namespace that within the range of root namespaces, but is not included + maxNs := nmt.MaxNamespace(root, share.NamespaceSize) + absentNs, err := share.Namespace(maxNs).AddInt(-1) + require.NoError(t, err) + + f := createAccessor(t, eds) + rowData, err := f.RowNamespaceData(ctx, absentNs, i) + require.NoError(t, err) + + // namespace is not included in the row, so there should be no shares + require.Len(t, rowData.Shares, 0) + require.True(t, rowData.Proof.IsOfAbsence()) + + err = rowData.Validate(dah, absentNs, i) + require.NoError(t, err) + } + }) +} + +func testAccessorAxisHalf( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + fl := createAccessor(t, eds) + + t.Run("single thread", func(t *testing.T) { + for _, axisType := range []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} { + for axisIdx := 0; axisIdx < int(eds.Width()); axisIdx++ { + half, err := fl.AxisHalf(ctx, axisType, axisIdx) + require.NoError(t, err) + require.Len(t, half.Shares, odsSize) + + var expected []share.Share + if half.IsParity { + expected = getAxis(eds, axisType, axisIdx)[odsSize:] + } else { + expected = getAxis(eds, axisType, axisIdx)[:odsSize] + } + + require.Equal(t, expected, half.Shares) + } + } + }) + + t.Run("parallel", func(t *testing.T) { + wg := sync.WaitGroup{} + for _, axisType := range []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} { + for i := 0; i < int(eds.Width()); i++ { + wg.Add(1) + go func(axisType rsmt2d.Axis, idx int) { + defer wg.Done() + half, err := fl.AxisHalf(ctx, axisType, idx) + require.NoError(t, err) + require.Len(t, half.Shares, odsSize) + + var expected []share.Share + if half.IsParity { + expected = getAxis(eds, axisType, idx)[odsSize:] + } else { + expected = getAxis(eds, axisType, idx)[:odsSize] + } + + require.Equal(t, expected, half.Shares) + }(axisType, i) + } + } + wg.Wait() + }) +} + +func testAccessorShares( + ctx context.Context, + t *testing.T, + createAccessor createAccessor, + odsSize int, +) { + eds := edstest.RandEDS(t, odsSize) + fl := createAccessor(t, eds) + + shares, err := fl.Shares(ctx) + require.NoError(t, err) + expected := eds.FlattenedODS() + require.Equal(t, expected, shares) +} + +func BenchGetHalfAxisFromAccessor( + ctx context.Context, + b *testing.B, + newAccessor func(size int) Accessor, + minOdsSize, maxOdsSize int, +) { + for size := minOdsSize; size <= maxOdsSize; size *= 2 { + f := newAccessor(size) + + // loop over all possible axis types and quadrants + for _, squareHalf := range []int{0, 1} { + for _, axisType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { + name := fmt.Sprintf("Size:%v/ProofType:%s/squareHalf:%s", size, axisType, strconv.Itoa(squareHalf)) + b.Run(name, func(b 
*testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := f.AxisHalf(ctx, axisType, f.Size(ctx)/2*(squareHalf)) + require.NoError(b, err) + } + }) + } + } + } +} + +func BenchGetSampleFromAccessor( + ctx context.Context, + b *testing.B, + newAccessor func(size int) Accessor, + minOdsSize, maxOdsSize int, +) { + for size := minOdsSize; size <= maxOdsSize; size *= 2 { + f := newAccessor(size) + + // loop over all possible axis types and quadrants + for _, q := range quadrants { + name := fmt.Sprintf("Size:%v/quadrant:%s", size, q) + b.Run(name, func(b *testing.B) { + rowIdx, colIdx := q.coordinates(f.Size(ctx)) + // warm up cache + _, err := f.Sample(ctx, rowIdx, colIdx) + require.NoError(b, err, q.String()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := f.Sample(ctx, rowIdx, colIdx) + require.NoError(b, err) + } + }) + } + } +} + +type quadrant int + +var quadrants = []quadrant{1, 2, 3, 4} + +func (q quadrant) String() string { + return strconv.Itoa(int(q)) +} + +func (q quadrant) coordinates(edsSize int) (rowIdx, colIdx int) { + colIdx = edsSize/2*(int(q-1)%2) + 1 + rowIdx = edsSize/2*(int(q-1)/2) + 1 + return rowIdx, colIdx +} diff --git a/share/root.go b/share/root.go new file mode 100644 index 0000000000..8b14b9979e --- /dev/null +++ b/share/root.go @@ -0,0 +1,82 @@ +package share + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "hash" + + "github.com/celestiaorg/celestia-app/pkg/da" + "github.com/celestiaorg/rsmt2d" +) + +// Root represents root commitment to multiple Shares. +// In practice, it is a commitment to all the Data in a square. +type Root = da.DataAvailabilityHeader + +// DataHash is a representation of the Root hash. +type DataHash []byte + +func (dh DataHash) Validate() error { + if len(dh) != 32 { + return fmt.Errorf("invalid hash size, expected 32, got %d", len(dh)) + } + return nil +} + +func (dh DataHash) String() string { + return fmt.Sprintf("%X", []byte(dh)) +} + +// IsEmptyRoot check whether DataHash corresponds to the root of an empty block EDS. +func (dh DataHash) IsEmptyRoot() bool { + return bytes.Equal(EmptyRoot().Hash(), dh) +} + +// NewSHA256Hasher returns a new instance of a SHA-256 hasher. +func NewSHA256Hasher() hash.Hash { + return sha256.New() +} + +// NewRoot generates Root(DataAvailabilityHeader) using the +// provided extended data square. +func NewRoot(eds *rsmt2d.ExtendedDataSquare) (*Root, error) { + dah, err := da.NewDataAvailabilityHeader(eds) + if err != nil { + return nil, err + } + return &dah, nil +} + +// RowsWithNamespace inspects the Root for the Namespace and provides +// a slices of Row indexes containing the namespace. +func RowsWithNamespace(root *Root, namespace Namespace) (idxs []int) { + for i, row := range root.RowRoots { + if !namespace.IsOutsideRange(row, row) { + idxs = append(idxs, i) + } + } + return +} + +// RootHashForCoordinates returns the root hash for the given coordinates. +func RootHashForCoordinates(r *Root, axisType rsmt2d.Axis, rowIdx, colIdx uint) []byte { + if axisType == rsmt2d.Row { + return r.RowRoots[rowIdx] + } + return r.ColumnRoots[colIdx] +} + +// MustDataHashFromString converts a hex string to a valid datahash. 
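+// It panics if the input is not valid hex or if the decoded hash fails validation.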
+func MustDataHashFromString(datahash string) DataHash { + dh, err := hex.DecodeString(datahash) + if err != nil { + panic(fmt.Sprintf("datahash conversion: passed string was not valid hex: %s", datahash)) + } + err = DataHash(dh).Validate() + if err != nil { + panic(fmt.Sprintf("datahash validation: passed hex string failed: %s", err)) + } + return dh +} diff --git a/share/share.go b/share/share.go index 298c3d46d4..b97bab3589 100644 --- a/share/share.go +++ b/share/share.go @@ -1,13 +1,12 @@ package share import ( - "bytes" "crypto/sha256" - "encoding/hex" "fmt" - "hash" "github.com/celestiaorg/celestia-app/pkg/appconsts" + "github.com/celestiaorg/nmt" + "github.com/celestiaorg/rsmt2d" ) // DefaultRSMT2DCodec sets the default rsmt2d.Codec for shares. @@ -38,39 +37,35 @@ func GetData(s Share) []byte { return s[NamespaceSize:] } -// DataHash is a representation of the Root hash. -type DataHash []byte - -func (dh DataHash) Validate() error { - if len(dh) != 32 { - return fmt.Errorf("invalid hash size, expected 32, got %d", len(dh)) +// ValidateShare checks the size of a given share. +func ValidateShare(s Share) error { + if len(s) != Size { + return fmt.Errorf("invalid share size: %d", len(s)) } return nil } -func (dh DataHash) String() string { - return fmt.Sprintf("%X", []byte(dh)) -} - -// IsEmptyRoot check whether DataHash corresponds to the root of an empty block EDS. -func (dh DataHash) IsEmptyRoot() bool { - return bytes.Equal(EmptyRoot().Hash(), dh) +// ShareWithProof contains data with corresponding Merkle Proof +type ShareWithProof struct { //nolint: revive + // Share is a full data including namespace + Share + // Proof is a Merkle Proof of current share + Proof *nmt.Proof + // Axis is a type of axis against which the share proof is computed + Axis rsmt2d.Axis } -// MustDataHashFromString converts a hex string to a valid datahash. -func MustDataHashFromString(datahash string) DataHash { - dh, err := hex.DecodeString(datahash) - if err != nil { - panic(fmt.Sprintf("datahash conversion: passed string was not valid hex: %s", datahash)) +// Validate validates inclusion of the share under the given root CID. +func (s *ShareWithProof) Validate(rootHash []byte, x, y, edsSize int) bool { + isParity := x >= edsSize/2 || y >= edsSize/2 + namespace := ParitySharesNamespace + if !isParity { + namespace = GetNamespace(s.Share) } - err = DataHash(dh).Validate() - if err != nil { - panic(fmt.Sprintf("datahash validation: passed hex string failed: %s", err)) - } - return dh -} - -// NewSHA256Hasher returns a new instance of a SHA-256 hasher. -func NewSHA256Hasher() hash.Hash { - return sha256.New() + return s.Proof.VerifyInclusion( + sha256.New(), // TODO(@Wondertan): This should be defined somewhere globally + namespace.ToNMT(), + [][]byte{s.Share}, + rootHash, + ) } diff --git a/share/sharetest/testing.go b/share/sharetest/testing.go index 3889260393..6564af9b06 100644 --- a/share/sharetest/testing.go +++ b/share/sharetest/testing.go @@ -38,17 +38,26 @@ func RandShares(t require.TestingT, total int) []share.Share { } // RandSharesWithNamespace is same the as RandShares, but sets same namespace for all shares. 
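+// Only the first namespacedAmount shares are assigned the given namespace; the remaining
+// shares get random namespaces.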
-func RandSharesWithNamespace(t require.TestingT, namespace share.Namespace, total int) []share.Share { +func RandSharesWithNamespace(t require.TestingT, namespace share.Namespace, namespacedAmount, total int) []share.Share { if total&(total-1) != 0 { t.Errorf("total must be power of 2: %d", total) t.FailNow() } + if namespacedAmount > total { + t.Errorf("withNamespace must be less than total: %d", total) + t.FailNow() + } + shares := make([]share.Share, total) rnd := rand.New(rand.NewSource(time.Now().Unix())) //nolint:gosec for i := range shares { shr := make([]byte, share.Size) - copy(share.GetNamespace(shr), namespace) + if i < namespacedAmount { + copy(share.GetNamespace(shr), namespace) + } else { + copy(share.GetNamespace(shr), RandV0Namespace()) + } _, err := rnd.Read(share.GetData(shr)) require.NoError(t, err) shares[i] = shr diff --git a/share/shwap/eds_id.go b/share/shwap/eds_id.go new file mode 100644 index 0000000000..fe0734ad75 --- /dev/null +++ b/share/shwap/eds_id.go @@ -0,0 +1,63 @@ +package shwap + +import ( + "encoding/binary" + "fmt" + + "github.com/celestiaorg/celestia-node/share" +) + +// EdsIDSize defines the byte size of the EdsID. +const EdsIDSize = 8 + +// EdsID represents a unique identifier for a row, using the height of the block +// to identify the data square in the chain. +type EdsID struct { + Height uint64 // Height specifies the block height. +} + +// NewEdsID creates a new EdsID using the given height and verifies it against the provided Root. +// It returns an error if the verification fails. +func NewEdsID(height uint64, root *share.Root) (EdsID, error) { + eid := EdsID{ + Height: height, + } + return eid, eid.Validate(root) +} + +// EdsIDFromBinary decodes a byte slice into an EdsID, validating the length of the data. +// It returns an error if the data slice does not match the expected size of an EdsID. +func EdsIDFromBinary(data []byte) (EdsID, error) { + if len(data) != EdsIDSize { + return EdsID{}, fmt.Errorf("invalid EdsID data length: %d != %d", len(data), EdsIDSize) + } + rid := EdsID{ + Height: binary.BigEndian.Uint64(data), + } + return rid, nil +} + +// MarshalBinary encodes an EdsID into its binary form, primarily for storage or network +// transmission. +func (eid EdsID) MarshalBinary() ([]byte, error) { + data := make([]byte, 0, EdsIDSize) + return eid.appendTo(data), nil +} + +// Validate checks the integrity of an EdsID's fields against the provided Root. +// It ensures that the EdsID is not constructed with a zero Height and that the root is not nil. +func (eid EdsID) Validate(root *share.Root) error { + if root == nil { + return fmt.Errorf("provided Root is nil") + } + if eid.Height == 0 { + return fmt.Errorf("height cannot be zero") + } + return nil +} + +// appendTo helps in the binary encoding of EdsID by appending the binary form of Height to the +// given byte slice. 
+func (eid EdsID) appendTo(data []byte) []byte { + return binary.BigEndian.AppendUint64(data, eid.Height) +} diff --git a/share/shwap/eds_id_test.go b/share/shwap/eds_id_test.go new file mode 100644 index 0000000000..58b3cf0916 --- /dev/null +++ b/share/shwap/eds_id_test.go @@ -0,0 +1,30 @@ +package shwap + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" +) + +func TestEdsID(t *testing.T) { + square := edstest.RandEDS(t, 2) + root, err := share.NewRoot(square) + require.NoError(t, err) + + id, err := NewEdsID(2, root) + require.NoError(t, err) + + data, err := id.MarshalBinary() + require.NoError(t, err) + + idOut, err := EdsIDFromBinary(data) + require.NoError(t, err) + assert.EqualValues(t, id, idOut) + + err = idOut.Validate(root) + require.NoError(t, err) +} diff --git a/share/shwap/namespace_data.go b/share/shwap/namespace_data.go new file mode 100644 index 0000000000..687958be15 --- /dev/null +++ b/share/shwap/namespace_data.go @@ -0,0 +1,35 @@ +package shwap + +import ( + "fmt" + + "github.com/celestiaorg/celestia-node/share" +) + +// NamespacedData stores collections of RowNamespaceData, each representing shares and their proofs +// within a namespace. +type NamespacedData []RowNamespaceData + +// Flatten combines all shares from all rows within the namespace into a single slice. +func (ns NamespacedData) Flatten() []share.Share { + var shares []share.Share + for _, row := range ns { + shares = append(shares, row.Shares...) + } + return shares +} + +// Validate checks the integrity of the NamespacedData against a provided root and namespace. +func (ns NamespacedData) Validate(root *share.Root, namespace share.Namespace) error { + rowIdxs := share.RowsWithNamespace(root, namespace) + if len(rowIdxs) != len(ns) { + return fmt.Errorf("expected %d rows, found %d rows", len(rowIdxs), len(ns)) + } + + for i, row := range ns { + if err := row.Validate(root, namespace, rowIdxs[i]); err != nil { + return fmt.Errorf("validating row: %w", err) + } + } + return nil +} diff --git a/share/shwap/pb/shwap.pb.go b/share/shwap/pb/shwap.pb.go new file mode 100644 index 0000000000..000bf78ca7 --- /dev/null +++ b/share/shwap/pb/shwap.pb.go @@ -0,0 +1,1114 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: share/shwap/pb/shwap.proto + +package pb + +import ( + fmt "fmt" + pb "github.com/celestiaorg/nmt/pb" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. 
+const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type AxisType int32 + +const ( + AxisType_ROW AxisType = 0 + AxisType_COL AxisType = 1 +) + +var AxisType_name = map[int32]string{ + 0: "ROW", + 1: "COL", +} + +var AxisType_value = map[string]int32{ + "ROW": 0, + "COL": 1, +} + +func (x AxisType) String() string { + return proto.EnumName(AxisType_name, int32(x)) +} + +func (AxisType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{0} +} + +type Row_HalfSide int32 + +const ( + Row_LEFT Row_HalfSide = 0 + Row_RIGHT Row_HalfSide = 1 +) + +var Row_HalfSide_name = map[int32]string{ + 0: "LEFT", + 1: "RIGHT", +} + +var Row_HalfSide_value = map[string]int32{ + "LEFT": 0, + "RIGHT": 1, +} + +func (x Row_HalfSide) String() string { + return proto.EnumName(Row_HalfSide_name, int32(x)) +} + +func (Row_HalfSide) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{0, 0} +} + +type Row struct { + SharesHalf []*Share `protobuf:"bytes,1,rep,name=shares_half,json=sharesHalf,proto3" json:"shares_half,omitempty"` + HalfSide Row_HalfSide `protobuf:"varint,2,opt,name=half_side,json=halfSide,proto3,enum=shwap.Row_HalfSide" json:"half_side,omitempty"` +} + +func (m *Row) Reset() { *m = Row{} } +func (m *Row) String() string { return proto.CompactTextString(m) } +func (*Row) ProtoMessage() {} +func (*Row) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{0} +} +func (m *Row) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Row) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Row.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Row) XXX_Merge(src proto.Message) { + xxx_messageInfo_Row.Merge(m, src) +} +func (m *Row) XXX_Size() int { + return m.Size() +} +func (m *Row) XXX_DiscardUnknown() { + xxx_messageInfo_Row.DiscardUnknown(m) +} + +var xxx_messageInfo_Row proto.InternalMessageInfo + +func (m *Row) GetSharesHalf() []*Share { + if m != nil { + return m.SharesHalf + } + return nil +} + +func (m *Row) GetHalfSide() Row_HalfSide { + if m != nil { + return m.HalfSide + } + return Row_LEFT +} + +type Sample struct { + Share *Share `protobuf:"bytes,1,opt,name=share,proto3" json:"share,omitempty"` + Proof *pb.Proof `protobuf:"bytes,2,opt,name=proof,proto3" json:"proof,omitempty"` + ProofType AxisType `protobuf:"varint,3,opt,name=proof_type,json=proofType,proto3,enum=shwap.AxisType" json:"proof_type,omitempty"` +} + +func (m *Sample) Reset() { *m = Sample{} } +func (m *Sample) String() string { return proto.CompactTextString(m) } +func (*Sample) ProtoMessage() {} +func (*Sample) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{1} +} +func (m *Sample) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Sample) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Sample.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Sample) XXX_Merge(src proto.Message) { + xxx_messageInfo_Sample.Merge(m, src) +} +func (m *Sample) XXX_Size() int { + return m.Size() +} +func (m *Sample) XXX_DiscardUnknown() { + xxx_messageInfo_Sample.DiscardUnknown(m) +} + +var xxx_messageInfo_Sample 
proto.InternalMessageInfo + +func (m *Sample) GetShare() *Share { + if m != nil { + return m.Share + } + return nil +} + +func (m *Sample) GetProof() *pb.Proof { + if m != nil { + return m.Proof + } + return nil +} + +func (m *Sample) GetProofType() AxisType { + if m != nil { + return m.ProofType + } + return AxisType_ROW +} + +type RowNamespaceData struct { + Shares []*Share `protobuf:"bytes,1,rep,name=shares,proto3" json:"shares,omitempty"` + Proof *pb.Proof `protobuf:"bytes,2,opt,name=proof,proto3" json:"proof,omitempty"` +} + +func (m *RowNamespaceData) Reset() { *m = RowNamespaceData{} } +func (m *RowNamespaceData) String() string { return proto.CompactTextString(m) } +func (*RowNamespaceData) ProtoMessage() {} +func (*RowNamespaceData) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{2} +} +func (m *RowNamespaceData) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *RowNamespaceData) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_RowNamespaceData.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *RowNamespaceData) XXX_Merge(src proto.Message) { + xxx_messageInfo_RowNamespaceData.Merge(m, src) +} +func (m *RowNamespaceData) XXX_Size() int { + return m.Size() +} +func (m *RowNamespaceData) XXX_DiscardUnknown() { + xxx_messageInfo_RowNamespaceData.DiscardUnknown(m) +} + +var xxx_messageInfo_RowNamespaceData proto.InternalMessageInfo + +func (m *RowNamespaceData) GetShares() []*Share { + if m != nil { + return m.Shares + } + return nil +} + +func (m *RowNamespaceData) GetProof() *pb.Proof { + if m != nil { + return m.Proof + } + return nil +} + +type Share struct { + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (m *Share) Reset() { *m = Share{} } +func (m *Share) String() string { return proto.CompactTextString(m) } +func (*Share) ProtoMessage() {} +func (*Share) Descriptor() ([]byte, []int) { + return fileDescriptor_9431653f3c9f0bcb, []int{3} +} +func (m *Share) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Share) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Share.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Share) XXX_Merge(src proto.Message) { + xxx_messageInfo_Share.Merge(m, src) +} +func (m *Share) XXX_Size() int { + return m.Size() +} +func (m *Share) XXX_DiscardUnknown() { + xxx_messageInfo_Share.DiscardUnknown(m) +} + +var xxx_messageInfo_Share proto.InternalMessageInfo + +func (m *Share) GetData() []byte { + if m != nil { + return m.Data + } + return nil +} + +func init() { + proto.RegisterEnum("shwap.AxisType", AxisType_name, AxisType_value) + proto.RegisterEnum("shwap.Row_HalfSide", Row_HalfSide_name, Row_HalfSide_value) + proto.RegisterType((*Row)(nil), "shwap.Row") + proto.RegisterType((*Sample)(nil), "shwap.Sample") + proto.RegisterType((*RowNamespaceData)(nil), "shwap.RowNamespaceData") + proto.RegisterType((*Share)(nil), "shwap.Share") +} + +func init() { proto.RegisterFile("share/shwap/pb/shwap.proto", fileDescriptor_9431653f3c9f0bcb) } + +var fileDescriptor_9431653f3c9f0bcb = []byte{ + // 381 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 
0x92, 0x4f, 0x6b, 0xe2, 0x40, + 0x18, 0xc6, 0x33, 0x1b, 0xe3, 0xc6, 0x57, 0xd1, 0x30, 0x7b, 0x09, 0xee, 0x92, 0x95, 0xb0, 0x0b, + 0xb2, 0x60, 0xb2, 0xe8, 0x27, 0xd8, 0xbf, 0xb5, 0x60, 0x6b, 0x19, 0x85, 0x42, 0x2f, 0x61, 0x62, + 0x46, 0x13, 0x88, 0x9d, 0x21, 0x49, 0x49, 0x3d, 0xf7, 0xd0, 0x6b, 0x3f, 0x56, 0x8f, 0x1e, 0x7b, + 0x2c, 0xfa, 0x45, 0x4a, 0x26, 0xb1, 0x14, 0xda, 0x43, 0x6f, 0xbf, 0xcc, 0xf3, 0xcc, 0xbc, 0xcf, + 0x13, 0x5e, 0xe8, 0xa6, 0x21, 0x4d, 0x98, 0x9b, 0x86, 0x39, 0x15, 0xae, 0xf0, 0x4b, 0x70, 0x44, + 0xc2, 0x33, 0x8e, 0x35, 0xf9, 0xd1, 0x6d, 0x0b, 0xdf, 0x15, 0x09, 0xe7, 0xcb, 0xf2, 0xd8, 0xbe, + 0x45, 0xa0, 0x12, 0x9e, 0xe3, 0x01, 0x34, 0xe5, 0xe5, 0xd4, 0x0b, 0x69, 0xbc, 0x34, 0x51, 0x4f, + 0xed, 0x37, 0x87, 0x2d, 0xa7, 0x7c, 0x61, 0x56, 0x28, 0x04, 0x4a, 0xc3, 0x98, 0xc6, 0x4b, 0xfc, + 0x13, 0x1a, 0x85, 0xcf, 0x4b, 0xa3, 0x80, 0x99, 0x1f, 0x7a, 0xa8, 0xdf, 0x1e, 0x7e, 0xaa, 0xcc, + 0x84, 0xe7, 0x4e, 0xe1, 0x99, 0x45, 0x01, 0x23, 0x7a, 0x58, 0x91, 0xfd, 0x15, 0xf4, 0xc3, 0x29, + 0xd6, 0xa1, 0x36, 0xf9, 0xf7, 0x7f, 0x6e, 0x28, 0xb8, 0x01, 0x1a, 0x39, 0x3e, 0x1a, 0xcf, 0x0d, + 0x64, 0xdf, 0x20, 0xa8, 0xcf, 0xe8, 0x5a, 0xc4, 0x0c, 0xdb, 0xa0, 0xc9, 0x59, 0x26, 0xea, 0xa1, + 0x57, 0x31, 0x4a, 0x09, 0x7f, 0x07, 0x4d, 0xf6, 0x90, 0xd3, 0x9b, 0xc3, 0x8e, 0x53, 0xb5, 0xf2, + 0x9d, 0xb3, 0x02, 0x48, 0xa9, 0x62, 0x07, 0x40, 0x82, 0x97, 0x6d, 0x04, 0x33, 0x55, 0x99, 0xb4, + 0x53, 0xbd, 0xf7, 0xeb, 0x3a, 0x4a, 0xe7, 0x1b, 0xc1, 0x48, 0x43, 0x5a, 0x0a, 0xb4, 0x3d, 0x30, + 0x08, 0xcf, 0x4f, 0xe9, 0x9a, 0xa5, 0x82, 0x2e, 0xd8, 0x5f, 0x9a, 0x51, 0xfc, 0x0d, 0xea, 0x65, + 0xf5, 0x37, 0x7f, 0x4b, 0xa5, 0xbd, 0x33, 0x90, 0xfd, 0x19, 0x34, 0x79, 0x0f, 0x63, 0xa8, 0x05, + 0x34, 0xa3, 0xb2, 0x63, 0x8b, 0x48, 0xfe, 0xf1, 0x05, 0xf4, 0x43, 0x28, 0xfc, 0x11, 0x54, 0x32, + 0x3d, 0x37, 0x94, 0x02, 0xfe, 0x4c, 0x27, 0x06, 0xfa, 0x7d, 0x72, 0xbf, 0xb3, 0xd0, 0x76, 0x67, + 0xa1, 0xc7, 0x9d, 0x85, 0xee, 0xf6, 0x96, 0xb2, 0xdd, 0x5b, 0xca, 0xc3, 0xde, 0x52, 0x2e, 0x46, + 0xab, 0x28, 0x0b, 0xaf, 0x7c, 0x67, 0xc1, 0xd7, 0xee, 0x82, 0xc5, 0x2c, 0xcd, 0x22, 0xca, 0x93, + 0xd5, 0x33, 0x0f, 0x2e, 0x79, 0x50, 0xec, 0xc5, 0xcb, 0xed, 0xf0, 0xeb, 0x72, 0x03, 0x46, 0x4f, + 0x01, 0x00, 0x00, 0xff, 0xff, 0x67, 0xb6, 0xc0, 0x8b, 0x36, 0x02, 0x00, 0x00, +} + +func (m *Row) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Row) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Row) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.HalfSide != 0 { + i = encodeVarintShwap(dAtA, i, uint64(m.HalfSide)) + i-- + dAtA[i] = 0x10 + } + if len(m.SharesHalf) > 0 { + for iNdEx := len(m.SharesHalf) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.SharesHalf[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func (m *Sample) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Sample) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Sample) MarshalToSizedBuffer(dAtA []byte) (int, error) { + 
i := len(dAtA) + _ = i + var l int + _ = l + if m.ProofType != 0 { + i = encodeVarintShwap(dAtA, i, uint64(m.ProofType)) + i-- + dAtA[i] = 0x18 + } + if m.Proof != nil { + { + size, err := m.Proof.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + if m.Share != nil { + { + size, err := m.Share.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *RowNamespaceData) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *RowNamespaceData) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *RowNamespaceData) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Proof != nil { + { + size, err := m.Proof.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + if len(m.Shares) > 0 { + for iNdEx := len(m.Shares) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.Shares[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintShwap(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func (m *Share) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Share) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Share) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Data) > 0 { + i -= len(m.Data) + copy(dAtA[i:], m.Data) + i = encodeVarintShwap(dAtA, i, uint64(len(m.Data))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintShwap(dAtA []byte, offset int, v uint64) int { + offset -= sovShwap(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *Row) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.SharesHalf) > 0 { + for _, e := range m.SharesHalf { + l = e.Size() + n += 1 + l + sovShwap(uint64(l)) + } + } + if m.HalfSide != 0 { + n += 1 + sovShwap(uint64(m.HalfSide)) + } + return n +} + +func (m *Sample) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Share != nil { + l = m.Share.Size() + n += 1 + l + sovShwap(uint64(l)) + } + if m.Proof != nil { + l = m.Proof.Size() + n += 1 + l + sovShwap(uint64(l)) + } + if m.ProofType != 0 { + n += 1 + sovShwap(uint64(m.ProofType)) + } + return n +} + +func (m *RowNamespaceData) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Shares) > 0 { + for _, e := range m.Shares { + l = e.Size() + n += 1 + l + sovShwap(uint64(l)) + } + } + if m.Proof != nil { + l = m.Proof.Size() + n += 1 + l + sovShwap(uint64(l)) + } + return n +} + +func (m *Share) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Data) + if l > 0 { + n += 1 + l + sovShwap(uint64(l)) + } + 
return n +} + +func sovShwap(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozShwap(x uint64) (n int) { + return sovShwap(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *Row) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Row: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Row: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SharesHalf", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SharesHalf = append(m.SharesHalf, &Share{}) + if err := m.SharesHalf[len(m.SharesHalf)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field HalfSide", wireType) + } + m.HalfSide = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.HalfSide |= Row_HalfSide(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Sample) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Sample: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Sample: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Share", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Share == nil { + m.Share = 
&Share{} + } + if err := m.Share.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Proof", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Proof == nil { + m.Proof = &pb.Proof{} + } + if err := m.Proof.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field ProofType", wireType) + } + m.ProofType = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.ProofType |= AxisType(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *RowNamespaceData) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: RowNamespaceData: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: RowNamespaceData: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Shares", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Shares = append(m.Shares, &Share{}) + if err := m.Shares[len(m.Shares)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Proof", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Proof == nil { + m.Proof = 
&pb.Proof{} + } + if err := m.Proof.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Share) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Share: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Share: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowShwap + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthShwap + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthShwap + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) + if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipShwap(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthShwap + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipShwap(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowShwap + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowShwap + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowShwap + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthShwap + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupShwap + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthShwap + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthShwap = fmt.Errorf("proto: negative length 
found during unmarshaling") + ErrIntOverflowShwap = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupShwap = fmt.Errorf("proto: unexpected end of group") +) diff --git a/share/shwap/pb/shwap.proto b/share/shwap/pb/shwap.proto new file mode 100644 index 0000000000..d7daea568a --- /dev/null +++ b/share/shwap/pb/shwap.proto @@ -0,0 +1,36 @@ +// Defined in CIP-19 https://github.com/celestiaorg/CIPs/blob/82aeb7dfc472105a11babffd548c730c899a3d24/cips/cip-19.md +syntax = "proto3"; +package shwap; +option go_package = "github.com/celestiaorg/celestia-node/share/shwap/pb"; + +import "pb/proof.proto"; // celestiaorg/nmt/pb/proof.proto + +message Row { + repeated Share shares_half = 1; + HalfSide half_side= 2; + + enum HalfSide { + LEFT = 0; + RIGHT = 1; + } +} + +message Sample { + Share share = 1; + proof.pb.Proof proof = 2; + AxisType proof_type = 3; +} + +message RowNamespaceData { + repeated Share shares = 1; + proof.pb.Proof proof = 2; +} + +message Share { + bytes data = 1; +} + +enum AxisType { + ROW = 0; + COL = 1; +} diff --git a/share/shwap/row.go b/share/shwap/row.go new file mode 100644 index 0000000000..f03c7366f5 --- /dev/null +++ b/share/shwap/row.go @@ -0,0 +1,142 @@ +package shwap + +import ( + "bytes" + "fmt" + + "github.com/celestiaorg/celestia-app/pkg/wrapper" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap/pb" +) + +// RowSide enumerates the possible sides of a row within an Extended Data Square (EDS). +type RowSide int + +const ( + Left RowSide = iota // Left side of the row. + Right // Right side of the row. +) + +// Row represents a portion of a row in an EDS, either left or right half. +type Row struct { + halfShares []share.Share // halfShares holds the shares of either the left or right half of a row. + side RowSide // side indicates whether the row half is left or right. +} + +// NewRow creates a new Row with the specified shares and side. +func NewRow(halfShares []share.Share, side RowSide) Row { + return Row{ + halfShares: halfShares, + side: side, + } +} + +// RowFromEDS constructs a new Row from an Extended Data Square based on the specified index and +// side. +func RowFromShares(shares []share.Share, side RowSide) Row { + var halfShares []share.Share + if side == Right { + halfShares = shares[len(shares)/2:] // Take the right half of the shares. + } else { + halfShares = shares[:len(shares)/2] // Take the left half of the shares. + } + + return NewRow(halfShares, side) +} + +// RowFromProto converts a protobuf Row to a Row structure. +func RowFromProto(r *pb.Row) Row { + return Row{ + halfShares: SharesFromProto(r.SharesHalf), + side: sideFromProto(r.GetHalfSide()), + } +} + +// Shares reconstructs the complete row shares from the half provided, using RSMT2D for data +// recovery if needed. +func (r Row) Shares() ([]share.Share, error) { + shares := make([]share.Share, len(r.halfShares)*2) + offset := 0 + if r.side == Right { + offset = len(r.halfShares) // Position the halfShares in the second half if it's the right side. + } + for i, share := range r.halfShares { + shares[i+offset] = share + } + return share.DefaultRSMT2DCodec().Decode(shares) +} + +// ToProto converts the Row to its protobuf representation. +func (r Row) ToProto() *pb.Row { + return &pb.Row{ + SharesHalf: SharesToProto(r.halfShares), + HalfSide: r.side.ToProto(), + } +} + +// Validate checks if the row's shares match the expected number from the root data and validates +// the side of the row. 
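+// It also reconstructs the full row from the provided half and checks the resulting NMT root
+// against dah.RowRoots[idx], wrapping ErrFailedVerification on mismatch.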
+func (r Row) Validate(dah *share.Root, idx int) error { + if len(r.halfShares) == 0 { + return fmt.Errorf("empty half row") + } + expectedShares := len(dah.RowRoots) / 2 + if len(r.halfShares) != expectedShares { + return fmt.Errorf("shares size doesn't match root size: %d != %d", len(r.halfShares), expectedShares) + } + if err := ValidateShares(r.halfShares); err != nil { + return fmt.Errorf("invalid shares: %w", err) + } + if r.side != Left && r.side != Right { + return fmt.Errorf("invalid RowSide: %d", r.side) + } + + if err := r.verifyInclusion(dah, idx); err != nil { + return fmt.Errorf("%w: %w", ErrFailedVerification, err) + } + return nil +} + +// verifyInclusion verifies the integrity of the row's shares against the provided root hash for the +// given row index. +func (r Row) verifyInclusion(dah *share.Root, idx int) error { + shrs, err := r.Shares() + if err != nil { + return fmt.Errorf("while extending shares: %w", err) + } + + sqrLn := uint64(len(shrs) / 2) + tree := wrapper.NewErasuredNamespacedMerkleTree(sqrLn, uint(idx)) + for _, s := range shrs { + if err := tree.Push(s); err != nil { + return fmt.Errorf("while pushing shares to NMT: %w", err) + } + } + + root, err := tree.Root() + if err != nil { + return fmt.Errorf("while computing NMT root: %w", err) + } + + if !bytes.Equal(dah.RowRoots[idx], root) { + return fmt.Errorf("invalid root hash: %X != %X", root, dah.RowRoots[idx]) + } + return nil +} + +// ToProto converts a RowSide to its protobuf representation. +func (s RowSide) ToProto() pb.Row_HalfSide { + if s == Left { + return pb.Row_LEFT + } + return pb.Row_RIGHT +} + +// sideFromProto converts a protobuf Row_HalfSide back to a RowSide. +func sideFromProto(side pb.Row_HalfSide) RowSide { + if side == pb.Row_LEFT { + return Left + } + return Right +} diff --git a/share/shwap/row_id.go b/share/shwap/row_id.go new file mode 100644 index 0000000000..cfdde4943f --- /dev/null +++ b/share/shwap/row_id.go @@ -0,0 +1,79 @@ +package shwap + +import ( + "encoding/binary" + "fmt" + + "github.com/celestiaorg/celestia-node/share" +) + +// RowIDSize defines the size in bytes of RowID, consisting of the size of EdsID and 2 bytes for +// RowIndex. +const RowIDSize = EdsIDSize + 2 + +// RowID uniquely identifies a row in the data square of a blockchain block, combining block height +// with the row's index. +type RowID struct { + EdsID // Embedding EdsID to include the block height in RowID. + RowIndex int // RowIndex specifies the position of the row within the data square. +} + +// NewRowID creates a new RowID with the specified block height, row index, and validates it +// against the provided Root. It returns an error if the validation fails, ensuring the RowID +// conforms to expected constraints. +func NewRowID(height uint64, rowIdx int, root *share.Root) (RowID, error) { + rid := RowID{ + EdsID: EdsID{ + Height: height, + }, + RowIndex: rowIdx, + } + return rid, rid.Validate(root) +} + +// RowIDFromBinary decodes a RowID from its binary representation. +// It returns an error if the input data does not conform to the expected size or content format. 
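+// The expected layout is the serialized EdsID followed by the row index as a big-endian uint16,
+// RowIDSize bytes in total. For illustration (sketch; assumes height, rowIdx and root are valid):
+//
+//	rid, _ := NewRowID(height, rowIdx, root)
+//	data, _ := rid.MarshalBinary()
+//	rid, _ = RowIDFromBinary(data) // round-trips to an equal RowID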
+func RowIDFromBinary(data []byte) (RowID, error) { + if len(data) != RowIDSize { + return RowID{}, fmt.Errorf("invalid RowID data length: expected %d, got %d", RowIDSize, len(data)) + } + eid, err := EdsIDFromBinary(data[:EdsIDSize]) + if err != nil { + return RowID{}, fmt.Errorf("error decoding EdsID: %w", err) + } + return RowID{ + EdsID: eid, + RowIndex: int(binary.BigEndian.Uint16(data[EdsIDSize:])), + }, nil +} + +// MarshalBinary encodes the RowID into a binary form for storage or network transmission. +func (rid RowID) MarshalBinary() ([]byte, error) { + data := make([]byte, 0, RowIDSize) + return rid.appendTo(data), nil +} + +// Validate ensures the RowID's fields are valid given the specified root structure, particularly +// that the row index is within bounds. +func (rid RowID) Validate(root *share.Root) error { + if err := rid.EdsID.Validate(root); err != nil { + return err + } + + if root == nil || len(root.RowRoots) == 0 { + return fmt.Errorf("provided root is nil or empty") + } + + if rid.RowIndex >= len(root.RowRoots) { + return fmt.Errorf("RowIndex out of bounds: %d >= %d", rid.RowIndex, len(root.RowRoots)) + } + + return nil +} + +// appendTo assists in binary encoding of RowID by appending the encoded fields to the given byte +// slice. +func (rid RowID) appendTo(data []byte) []byte { + data = rid.EdsID.appendTo(data) + return binary.BigEndian.AppendUint16(data, uint16(rid.RowIndex)) +} diff --git a/share/shwap/row_id_test.go b/share/shwap/row_id_test.go new file mode 100644 index 0000000000..410fe2782d --- /dev/null +++ b/share/shwap/row_id_test.go @@ -0,0 +1,30 @@ +package shwap + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" +) + +func TestRowID(t *testing.T) { + square := edstest.RandEDS(t, 4) + root, err := share.NewRoot(square) + require.NoError(t, err) + + id, err := NewRowID(2, 1, root) + require.NoError(t, err) + + data, err := id.MarshalBinary() + require.NoError(t, err) + + idOut, err := RowIDFromBinary(data) + require.NoError(t, err) + assert.EqualValues(t, id, idOut) + + err = idOut.Validate(root) + require.NoError(t, err) +} diff --git a/share/shwap/row_namespace_data.go b/share/shwap/row_namespace_data.go new file mode 100644 index 0000000000..5d424ee0f3 --- /dev/null +++ b/share/shwap/row_namespace_data.go @@ -0,0 +1,181 @@ +package shwap + +import ( + "errors" + "fmt" + + "github.com/celestiaorg/celestia-app/pkg/appconsts" + "github.com/celestiaorg/celestia-app/pkg/wrapper" + "github.com/celestiaorg/nmt" + nmt_pb "github.com/celestiaorg/nmt/pb" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap/pb" +) + +// ErrNamespaceOutsideRange is returned by RowNamespaceDataFromShares when the target namespace is +// outside of the namespace range for the given row. In this case, the implementation cannot return +// the non-inclusion proof and will return ErrNamespaceOutsideRange. +var ErrNamespaceOutsideRange = errors.New("target namespace is outside of namespace range for the given root") + +// RowNamespaceData holds shares and their corresponding proof for a single row within a namespace. +type RowNamespaceData struct { + Shares []share.Share `json:"shares"` // Shares within the namespace. + Proof *nmt.Proof `json:"proof"` // Proof of the shares' inclusion in the namespace. 
+} + +// RowNamespaceDataFromShares extracts and constructs a RowNamespaceData from shares within the +// specified namespace. +func RowNamespaceDataFromShares( + shares []share.Share, + namespace share.Namespace, + rowIndex int, +) (RowNamespaceData, error) { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(rowIndex)) + nmtTree := nmt.New( + appconsts.NewBaseHashFunc(), + nmt.NamespaceIDSize(appconsts.NamespaceSize), + nmt.IgnoreMaxNamespace(true), + ) + tree.SetTree(nmtTree) + + for _, shr := range shares { + if err := tree.Push(shr); err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to build tree for row %d: %w", rowIndex, err) + } + } + + root, err := tree.Root() + if err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to get root for row %d: %w", rowIndex, err) + } + if namespace.IsOutsideRange(root, root) { + return RowNamespaceData{}, ErrNamespaceOutsideRange + } + + var from, count int + for i := range len(shares) / 2 { + if namespace.Equals(share.GetNamespace(shares[i])) { + if count == 0 { + from = i + } + count++ + continue + } + if count > 0 { + break + } + } + + // if count is 0, then the namespace is not present in the shares. Return non-inclusion proof. + if count == 0 { + proof, err := nmtTree.ProveNamespace(namespace.ToNMT()) + if err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to generate non-inclusion proof for row %d: %w", rowIndex, err) + } + + return RowNamespaceData{ + Proof: &proof, + }, nil + } + + namespacedShares := make([]share.Share, count) + copy(namespacedShares, shares[from:from+count]) + + proof, err := tree.ProveRange(from, from+count) + if err != nil { + return RowNamespaceData{}, fmt.Errorf("failed to generate proof for row %d: %w", rowIndex, err) + } + + return RowNamespaceData{ + Shares: namespacedShares, + Proof: &proof, + }, nil +} + +// RowNamespaceDataFromProto constructs RowNamespaceData out of its protobuf representation. +func RowNamespaceDataFromProto(row *pb.RowNamespaceData) RowNamespaceData { + var proof nmt.Proof + if row.GetProof().GetLeafHash() != nil { + proof = nmt.NewAbsenceProof( + int(row.GetProof().GetStart()), + int(row.GetProof().GetEnd()), + row.GetProof().GetNodes(), + row.GetProof().GetLeafHash(), + row.GetProof().GetIsMaxNamespaceIgnored(), + ) + } else { + proof = nmt.NewInclusionProof( + int(row.GetProof().GetStart()), + int(row.GetProof().GetEnd()), + row.GetProof().GetNodes(), + row.GetProof().GetIsMaxNamespaceIgnored(), + ) + } + + return RowNamespaceData{ + Shares: SharesFromProto(row.GetShares()), + Proof: &proof, + } +} + +// ToProto converts RowNamespaceData to its protobuf representation for serialization. +func (rnd RowNamespaceData) ToProto() *pb.RowNamespaceData { + return &pb.RowNamespaceData{ + Shares: SharesToProto(rnd.Shares), + Proof: &nmt_pb.Proof{ + Start: int64(rnd.Proof.Start()), + End: int64(rnd.Proof.End()), + Nodes: rnd.Proof.Nodes(), + LeafHash: rnd.Proof.LeafHash(), + IsMaxNamespaceIgnored: rnd.Proof.IsMaxNamespaceIDIgnored(), + }, + } +} + +// Validate checks validity of the RowNamespaceData against the Root, Namespace and Row index. 
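+// It requires a non-empty proof, shares to be present exactly when the proof is an inclusion
+// (non-absence) proof, the namespace to be within the row root's range, and the proof to verify
+// against dah.RowRoots[rowIdx].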
+func (rnd RowNamespaceData) Validate(dah *share.Root, namespace share.Namespace, rowIdx int) error { + if rnd.Proof == nil || rnd.Proof.IsEmptyProof() { + return fmt.Errorf("nil proof") + } + if len(rnd.Shares) == 0 && !rnd.Proof.IsOfAbsence() { + return fmt.Errorf("empty shares with non-absence proof for row %d", rowIdx) + } + + if len(rnd.Shares) > 0 && rnd.Proof.IsOfAbsence() { + return fmt.Errorf("non-empty shares with absence proof for row %d", rowIdx) + } + + if err := ValidateShares(rnd.Shares); err != nil { + return fmt.Errorf("invalid shares: %w", err) + } + + rowRoot := dah.RowRoots[rowIdx] + if namespace.IsOutsideRange(rowRoot, rowRoot) { + return fmt.Errorf("namespace out of range for row %d", rowIdx) + } + + if !rnd.verifyInclusion(rowRoot, namespace) { + return fmt.Errorf("%w for row: %d", ErrFailedVerification, rowIdx) + } + return nil +} + +// verifyInclusion checks the inclusion of the row's shares in the provided root using NMT. +func (rnd RowNamespaceData) verifyInclusion(rowRoot []byte, namespace share.Namespace) bool { + leaves := make([][]byte, 0, len(rnd.Shares)) + for _, sh := range rnd.Shares { + namespaceBytes := share.GetNamespace(sh) + leave := make([]byte, len(sh)+len(namespaceBytes)) + copy(leave, namespaceBytes) + copy(leave[len(namespaceBytes):], sh) + leaves = append(leaves, leave) + } + + return rnd.Proof.VerifyNamespace( + share.NewSHA256Hasher(), + namespace.ToNMT(), + leaves, + rowRoot, + ) +} diff --git a/share/shwap/row_namespace_data_id.go b/share/shwap/row_namespace_data_id.go new file mode 100644 index 0000000000..e31b39636e --- /dev/null +++ b/share/shwap/row_namespace_data_id.go @@ -0,0 +1,95 @@ +package shwap + +import ( + "fmt" + + "github.com/celestiaorg/celestia-node/share" +) + +// RowNamespaceDataIDSize defines the total size of a RowNamespaceDataID in bytes, combining the +// size of a RowID and the size of a Namespace. +const RowNamespaceDataIDSize = RowIDSize + share.NamespaceSize + +// RowNamespaceDataID uniquely identifies a piece of namespaced data within a row of an Extended +// Data Square (EDS). +type RowNamespaceDataID struct { + RowID // Embedded RowID representing the specific row in the EDS. + DataNamespace share.Namespace // DataNamespace is a string representation of the namespace to facilitate comparisons. +} + +// NewRowNamespaceDataID creates a new RowNamespaceDataID with the specified parameters. It +// validates the RowNamespaceDataID against the provided Root before returning. +func NewRowNamespaceDataID( + height uint64, + rowIdx int, + namespace share.Namespace, + root *share.Root, +) (RowNamespaceDataID, error) { + did := RowNamespaceDataID{ + RowID: RowID{ + EdsID: EdsID{ + Height: height, + }, + RowIndex: rowIdx, + }, + DataNamespace: namespace, + } + + if err := did.Validate(root); err != nil { + return RowNamespaceDataID{}, err + } + return did, nil +} + +// RowNamespaceDataIDFromBinary deserializes a RowNamespaceDataID from its binary form. It returns +// an error if the binary data's length does not match the expected size. 
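+// The expected layout is the serialized RowID followed by the raw namespace bytes,
+// RowNamespaceDataIDSize bytes in total; the namespace is checked with ValidateForData.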
+func RowNamespaceDataIDFromBinary(data []byte) (RowNamespaceDataID, error) { + if len(data) != RowNamespaceDataIDSize { + return RowNamespaceDataID{}, + fmt.Errorf("invalid RowNamespaceDataID length: expected %d, got %d", RowNamespaceDataIDSize, len(data)) + } + + rid, err := RowIDFromBinary(data[:RowIDSize]) + if err != nil { + return RowNamespaceDataID{}, fmt.Errorf("error unmarshaling RowID: %w", err) + } + + nsData := data[RowIDSize:] + ns := share.Namespace(nsData) + if err := ns.ValidateForData(); err != nil { + return RowNamespaceDataID{}, fmt.Errorf("error validating DataNamespace: %w", err) + } + + return RowNamespaceDataID{ + RowID: rid, + DataNamespace: ns, + }, nil +} + +// MarshalBinary encodes RowNamespaceDataID into binary form. +// NOTE: Proto is avoided because +// * Its size is not deterministic which is required for IPLD. +// * No support for uint16 +func (s RowNamespaceDataID) MarshalBinary() ([]byte, error) { + data := make([]byte, 0, RowNamespaceDataIDSize) + return s.appendTo(data), nil +} + +// Validate checks the validity of RowNamespaceDataID's fields, including the RowID and the +// namespace. +func (s RowNamespaceDataID) Validate(root *share.Root) error { + if err := s.RowID.Validate(root); err != nil { + return fmt.Errorf("error validating RowID: %w", err) + } + if err := s.DataNamespace.ValidateForData(); err != nil { + return fmt.Errorf("error validating DataNamespace: %w", err) + } + + return nil +} + +// appendTo helps in appending the binary form of DataNamespace to the serialized RowID data. +func (s RowNamespaceDataID) appendTo(data []byte) []byte { + data = s.RowID.appendTo(data) + return append(data, s.DataNamespace...) +} diff --git a/share/shwap/row_namespace_data_id_test.go b/share/shwap/row_namespace_data_id_test.go new file mode 100644 index 0000000000..bc965d4bc7 --- /dev/null +++ b/share/shwap/row_namespace_data_id_test.go @@ -0,0 +1,29 @@ +package shwap + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestDataID(t *testing.T) { + ns := sharetest.RandV0Namespace() + _, root := edstest.RandEDSWithNamespace(t, ns, 8, 4) + + id, err := NewRowNamespaceDataID(1, 1, ns, root) + require.NoError(t, err) + + data, err := id.MarshalBinary() + require.NoError(t, err) + + sidOut, err := RowNamespaceDataIDFromBinary(data) + require.NoError(t, err) + assert.EqualValues(t, id, sidOut) + + err = sidOut.Validate(root) + require.NoError(t, err) +} diff --git a/share/shwap/row_namespace_data_test.go b/share/shwap/row_namespace_data_test.go new file mode 100644 index 0000000000..07f434b87e --- /dev/null +++ b/share/shwap/row_namespace_data_test.go @@ -0,0 +1,99 @@ +package shwap_test + +import ( + "bytes" + "context" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/sharetest" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +func TestNamespacedRowFromShares(t *testing.T) { + const odsSize = 8 + + minNamespace, err := share.NewBlobNamespaceV0(slices.Concat(bytes.Repeat([]byte{0}, 8), []byte{1, 0})) + require.NoError(t, err) + err = minNamespace.ValidateForData() + require.NoError(t, err) + + for namespacedAmount := 1; namespacedAmount < 
odsSize; namespacedAmount++ { + shares := sharetest.RandSharesWithNamespace(t, minNamespace, namespacedAmount, odsSize) + parity, err := share.DefaultRSMT2DCodec().Encode(shares) + require.NoError(t, err) + extended := slices.Concat(shares, parity) + + nr, err := shwap.RowNamespaceDataFromShares(extended, minNamespace, 0) + require.NoError(t, err) + require.Equal(t, namespacedAmount, len(nr.Shares)) + } +} + +func TestNamespacedRowFromSharesNonIncluded(t *testing.T) { + // TODO: this will fail until absence proof support is added + t.Skip() + + const odsSize = 8 + // Test absent namespace + shares := sharetest.RandShares(t, odsSize) + absentNs, err := share.GetNamespace(shares[0]).AddInt(1) + require.NoError(t, err) + + parity, err := share.DefaultRSMT2DCodec().Encode(shares) + require.NoError(t, err) + extended := slices.Concat(shares, parity) + + nr, err := shwap.RowNamespaceDataFromShares(extended, absentNs, 0) + require.NoError(t, err) + require.Len(t, nr.Shares, 0) + require.True(t, nr.Proof.IsOfAbsence()) +} + +func TestValidateNamespacedRow(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + const odsSize = 8 + sharesAmount := odsSize * odsSize + namespace := sharetest.RandV0Namespace() + for amount := 1; amount < sharesAmount; amount++ { + randEDS, root := edstest.RandEDSWithNamespace(t, namespace, amount, odsSize) + rsmt2d := eds.Rsmt2D{ExtendedDataSquare: randEDS} + nd, err := eds.NamespacedData(ctx, root, rsmt2d, namespace) + require.NoError(t, err) + require.True(t, len(nd) > 0) + + rowIdxs := share.RowsWithNamespace(root, namespace) + require.Len(t, nd, len(rowIdxs)) + + for i, rowIdx := range rowIdxs { + err = nd[i].Validate(root, namespace, rowIdx) + require.NoError(t, err) + } + } +} + +func TestNamespacedRowProtoEncoding(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + const odsSize = 8 + namespace := sharetest.RandV0Namespace() + randEDS, root := edstest.RandEDSWithNamespace(t, namespace, odsSize, odsSize) + rsmt2d := eds.Rsmt2D{ExtendedDataSquare: randEDS} + nd, err := eds.NamespacedData(ctx, root, rsmt2d, namespace) + require.NoError(t, err) + require.True(t, len(nd) > 0) + + expected := nd[0] + pb := expected.ToProto() + ndOut := shwap.RowNamespaceDataFromProto(pb) + require.Equal(t, expected, ndOut) +} diff --git a/share/shwap/row_test.go b/share/shwap/row_test.go new file mode 100644 index 0000000000..5ea4c76b61 --- /dev/null +++ b/share/shwap/row_test.go @@ -0,0 +1,117 @@ +package shwap + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" +) + +func TestRowFromShares(t *testing.T) { + const odsSize = 8 + eds := edstest.RandEDS(t, odsSize) + + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for _, side := range []RowSide{Left, Right} { + shares := eds.Row(uint(rowIdx)) + row := RowFromShares(shares, side) + extended, err := row.Shares() + require.NoError(t, err) + require.Equal(t, shares, extended) + + var half []share.Share + if side == Right { + half = shares[odsSize:] + } else { + half = shares[:odsSize] + } + require.Equal(t, half, row.halfShares) + require.Equal(t, side, row.side) + } + } +} + +func TestRowValidate(t *testing.T) { + const odsSize = 8 + eds := edstest.RandEDS(t, odsSize) + root, err := share.NewRoot(eds) + require.NoError(t, err) + + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for _, side := 
range []RowSide{Left, Right} { + shares := eds.Row(uint(rowIdx)) + row := RowFromShares(shares, side) + + err := row.Validate(root, rowIdx) + require.NoError(t, err) + err = row.Validate(root, rowIdx) + require.NoError(t, err) + } + } +} + +func TestRowValidateNegativeCases(t *testing.T) { + eds := edstest.RandEDS(t, 8) // Generate a random Extended Data Square of size 8 + root, err := share.NewRoot(eds) + require.NoError(t, err) + shares := eds.Row(0) + row := RowFromShares(shares, Left) + + // Test with incorrect side specification + invalidSideRow := Row{halfShares: row.halfShares, side: RowSide(999)} + err = invalidSideRow.Validate(root, 0) + require.Error(t, err, "should error on invalid row side") + + // Test with invalid shares (more shares than expected) + incorrectShares := make([]share.Share, (eds.Width()/2)+1) // Adding an extra share + for i := range incorrectShares { + incorrectShares[i] = eds.GetCell(uint(i), 0) + } + invalidRow := Row{halfShares: incorrectShares, side: Left} + err = invalidRow.Validate(root, 0) + require.Error(t, err, "should error on incorrect number of shares") + + // Test with empty shares + emptyRow := Row{halfShares: []share.Share{}, side: Left} + err = emptyRow.Validate(root, 0) + require.Error(t, err, "should error on empty halfShares") + + // Doesn't match root. Corrupt root hash + root.RowRoots[0][len(root.RowRoots[0])-1] ^= 0xFF + err = row.Validate(root, 0) + require.Error(t, err, "should error on invalid root hash") +} + +func TestRowProtoEncoding(t *testing.T) { + const odsSize = 8 + eds := edstest.RandEDS(t, odsSize) + + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for _, side := range []RowSide{Left, Right} { + shares := eds.Row(uint(rowIdx)) + row := RowFromShares(shares, side) + + pb := row.ToProto() + rowOut := RowFromProto(pb) + require.Equal(t, row, rowOut) + } + } +} + +// BenchmarkRowValidate benchmarks the performance of row validation. +// BenchmarkRowValidate-10 9591 121802 ns/op +func BenchmarkRowValidate(b *testing.B) { + const odsSize = 32 + eds := edstest.RandEDS(b, odsSize) + root, err := share.NewRoot(eds) + require.NoError(b, err) + shares := eds.Row(0) + row := RowFromShares(shares, Left) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = row.Validate(root, 0) + } +} diff --git a/share/shwap/sample.go b/share/shwap/sample.go new file mode 100644 index 0000000000..cb263415ad --- /dev/null +++ b/share/shwap/sample.go @@ -0,0 +1,122 @@ +package shwap + +import ( + "errors" + "fmt" + + "github.com/celestiaorg/celestia-app/pkg/wrapper" + "github.com/celestiaorg/nmt" + nmt_pb "github.com/celestiaorg/nmt/pb" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap/pb" +) + +// ErrFailedVerification is returned when inclusion proof verification fails. It is returned +// when the data and the proof do not match trusted data root. +var ErrFailedVerification = errors.New("failed to verify inclusion") + +// Sample represents a data share along with its Merkle proof, used to validate the share's +// inclusion in a data square. +type Sample struct { + share.Share // Embeds the Share which includes the data with namespace. + Proof *nmt.Proof // Proof is the Merkle Proof validating the share's inclusion. + ProofType rsmt2d.Axis // ProofType indicates whether the proof is against a row or a column. +} + +// SampleFromShares creates a Sample from a list of shares, using the specified proof type and +// the share index to be included in the sample. 
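+// The shares are expected to be the full extended axis (both halves); the returned proof is an
+// NMT range proof for the single share at shrIdx within the axis at axisIdx. For illustration
+// (sketch; eds and the indices are assumed to be in scope):
+//
+//	axis := eds.Row(uint(rowIdx))
+//	s, err := SampleFromShares(axis, rsmt2d.Row, rowIdx, colIdx)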
+func SampleFromShares(shares []share.Share, proofType rsmt2d.Axis, axisIdx, shrIdx int) (Sample, error) { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(axisIdx)) + for _, shr := range shares { + err := tree.Push(shr) + if err != nil { + return Sample{}, err + } + } + + proof, err := tree.ProveRange(shrIdx, shrIdx+1) + if err != nil { + return Sample{}, err + } + + return Sample{ + Share: shares[shrIdx], + Proof: &proof, + ProofType: proofType, + }, nil +} + +// SampleFromProto converts a protobuf Sample back into its domain model equivalent. +func SampleFromProto(s *pb.Sample) Sample { + proof := nmt.NewInclusionProof( + int(s.GetProof().GetStart()), + int(s.GetProof().GetEnd()), + s.GetProof().GetNodes(), + s.GetProof().GetIsMaxNamespaceIgnored(), + ) + return Sample{ + Share: ShareFromProto(s.GetShare()), + Proof: &proof, + ProofType: rsmt2d.Axis(s.GetProofType()), + } +} + +// ToProto converts a Sample into its protobuf representation for serialization purposes. +func (s Sample) ToProto() *pb.Sample { + return &pb.Sample{ + Share: &pb.Share{Data: s.Share}, + Proof: &nmt_pb.Proof{ + Start: int64(s.Proof.Start()), + End: int64(s.Proof.End()), + Nodes: s.Proof.Nodes(), + LeafHash: s.Proof.LeafHash(), + IsMaxNamespaceIgnored: s.Proof.IsMaxNamespaceIDIgnored(), + }, + ProofType: pb.AxisType(s.ProofType), + } +} + +// Validate checks the inclusion of the share using its Merkle proof under the specified root. +// Returns an error if the proof is invalid or does not correspond to the indicated proof type. +func (s Sample) Validate(dah *share.Root, rowIdx, colIdx int) error { + if s.Proof == nil || s.Proof.IsEmptyProof() { + return errors.New("nil proof") + } + if err := share.ValidateShare(s.Share); err != nil { + return err + } + if s.ProofType != rsmt2d.Row && s.ProofType != rsmt2d.Col { + return fmt.Errorf("invalid SampleProofType: %d", s.ProofType) + } + if !s.verifyInclusion(dah, rowIdx, colIdx) { + return ErrFailedVerification + } + return nil +} + +// verifyInclusion checks if the share is included in the given root hash at the specified indices. +func (s Sample) verifyInclusion(dah *share.Root, rowIdx, colIdx int) bool { + size := len(dah.RowRoots) + namespace := inclusionNamespace(s.Share, rowIdx, colIdx, size) + rootHash := share.RootHashForCoordinates(dah, s.ProofType, uint(rowIdx), uint(colIdx)) + return s.Proof.VerifyInclusion( + share.NewSHA256Hasher(), + namespace.ToNMT(), + [][]byte{s.Share}, + rootHash, + ) +} + +// inclusionNamespace returns the namespace for the share based on its position in the square. +// Shares from extended part of the square are considered parity shares. It means that +// parity shares are located outside of first quadrant of the square. According to the nmt +// specification, the parity shares are prefixed with the namespace of the parity shares. +func inclusionNamespace(sh share.Share, rowIdx, colIdx, squareSize int) share.Namespace { + isParity := colIdx >= squareSize/2 || rowIdx >= squareSize/2 + if isParity { + return share.ParitySharesNamespace + } + return share.GetNamespace(sh) +} diff --git a/share/shwap/sample_id.go b/share/shwap/sample_id.go new file mode 100644 index 0000000000..33e83fe12d --- /dev/null +++ b/share/shwap/sample_id.go @@ -0,0 +1,90 @@ +package shwap + +import ( + "encoding/binary" + "fmt" + + "github.com/celestiaorg/celestia-node/share" +) + +// SampleIDSize defines the size of the SampleID in bytes, combining RowID size and 2 additional +// bytes for the ShareIndex. 
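+// A SampleID therefore serializes as the EdsID bytes, a big-endian uint16 row index and a
+// big-endian uint16 share index.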
+const SampleIDSize = RowIDSize + 2 + +// SampleID uniquely identifies a specific sample within a row of an Extended Data Square (EDS). +type SampleID struct { + RowID // Embeds RowID to incorporate block height and row index. + ShareIndex int // ShareIndex specifies the index of the sample within the row. +} + +// NewSampleID constructs a new SampleID using the provided block height, sample index, and a root +// structure for validation. It calculates the row and share index based on the sample index and +// the length of the row roots. +func NewSampleID(height uint64, rowIdx, colIdx int, root *share.Root) (SampleID, error) { + if root == nil || len(root.RowRoots) == 0 { + return SampleID{}, fmt.Errorf("invalid root: root is nil or empty") + } + sid := SampleID{ + RowID: RowID{ + EdsID: EdsID{ + Height: height, + }, + RowIndex: rowIdx, + }, + ShareIndex: colIdx, + } + + if err := sid.Validate(root); err != nil { + return SampleID{}, err + } + return sid, nil +} + +// SampleIDFromBinary deserializes a SampleID from binary data, ensuring the data length matches +// the expected size. +func SampleIDFromBinary(data []byte) (SampleID, error) { + if len(data) != SampleIDSize { + return SampleID{}, fmt.Errorf("invalid SampleID data length: expected %d, got %d", SampleIDSize, len(data)) + } + + rid, err := RowIDFromBinary(data[:RowIDSize]) + if err != nil { + return SampleID{}, fmt.Errorf("error decoding RowID: %w", err) + } + + return SampleID{ + RowID: rid, + ShareIndex: int(binary.BigEndian.Uint16(data[RowIDSize:])), + }, nil +} + +// MarshalBinary encodes SampleID into binary form. +// NOTE: Proto is avoided because +// * Its size is not deterministic which is required for IPLD. +// * No support for uint16 +func (sid SampleID) MarshalBinary() ([]byte, error) { + data := make([]byte, 0, SampleIDSize) + return sid.appendTo(data), nil +} + +// Validate checks the validity of the SampleID by ensuring the ShareIndex is within the bounds of +// the square size. +func (sid SampleID) Validate(root *share.Root) error { + if err := sid.RowID.Validate(root); err != nil { + return err + } + + sqrLn := len(root.ColumnRoots) // Assumes ColumnRoots is valid and populated. + if sid.ShareIndex >= sqrLn { + return fmt.Errorf("ShareIndex exceeds square size: %d >= %d", sid.ShareIndex, sqrLn) + } + + return nil +} + +// appendTo helps in constructing the binary representation by appending the encoded ShareIndex to +// the serialized RowID. 
+func (sid SampleID) appendTo(data []byte) []byte { + data = sid.RowID.appendTo(data) + return binary.BigEndian.AppendUint16(data, uint16(sid.ShareIndex)) +} diff --git a/share/shwap/sample_id_test.go b/share/shwap/sample_id_test.go new file mode 100644 index 0000000000..eafbc76550 --- /dev/null +++ b/share/shwap/sample_id_test.go @@ -0,0 +1,30 @@ +package shwap + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" +) + +func TestSampleID(t *testing.T) { + square := edstest.RandEDS(t, 4) + root, err := share.NewRoot(square) + require.NoError(t, err) + + id, err := NewSampleID(1, 1, 1, root) + require.NoError(t, err) + + data, err := id.MarshalBinary() + require.NoError(t, err) + + idOut, err := SampleIDFromBinary(data) + require.NoError(t, err) + assert.EqualValues(t, id, idOut) + + err = idOut.Validate(root) + require.NoError(t, err) +} diff --git a/share/shwap/sample_test.go b/share/shwap/sample_test.go new file mode 100644 index 0000000000..34fb4bfa9b --- /dev/null +++ b/share/shwap/sample_test.go @@ -0,0 +1,108 @@ +package shwap_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +func TestSampleValidate(t *testing.T) { + const odsSize = 8 + randEDS := edstest.RandEDS(t, odsSize) + root, err := share.NewRoot(randEDS) + require.NoError(t, err) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} + + for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { + for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { + for colIdx := 0; colIdx < odsSize*2; colIdx++ { + sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) + require.NoError(t, err) + + require.NoError(t, sample.Validate(root, rowIdx, colIdx)) + } + } + } +} + +// TestSampleNegativeVerifyInclusion checks +func TestSampleNegativeVerifyInclusion(t *testing.T) { + const odsSize = 8 + randEDS := edstest.RandEDS(t, odsSize) + root, err := share.NewRoot(randEDS) + require.NoError(t, err) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} + + sample, err := inMem.Sample(context.Background(), 0, 0) + require.NoError(t, err) + err = sample.Validate(root, 0, 0) + require.NoError(t, err) + + // incorrect row index + err = sample.Validate(root, 1, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) + + // Corrupt the share + sample.Share[0] ^= 0xFF + err = sample.Validate(root, 0, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) + + // incorrect proofType + sample, err = inMem.Sample(context.Background(), 0, 0) + require.NoError(t, err) + sample.ProofType = rsmt2d.Col + err = sample.Validate(root, 0, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) + + // Corrupt the last root hash byte + sample, err = inMem.Sample(context.Background(), 0, 0) + require.NoError(t, err) + root.RowRoots[0][len(root.RowRoots[0])-1] ^= 0xFF + err = sample.Validate(root, 0, 0) + require.ErrorIs(t, err, shwap.ErrFailedVerification) +} + +func TestSampleProtoEncoding(t *testing.T) { + const odsSize = 8 + randEDS := edstest.RandEDS(t, odsSize) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} + + for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { + for rowIdx := 0; rowIdx 
< odsSize*2; rowIdx++ { + for colIdx := 0; colIdx < odsSize*2; colIdx++ { + sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) + require.NoError(t, err) + + pb := sample.ToProto() + sampleOut := shwap.SampleFromProto(pb) + require.NoError(t, err) + require.Equal(t, sample, sampleOut) + } + } + } +} + +// BenchmarkSampleValidate benchmarks the performance of sample validation. +// BenchmarkSampleValidate-10 284829 3935 ns/op +func BenchmarkSampleValidate(b *testing.B) { + const odsSize = 32 + randEDS := edstest.RandEDS(b, odsSize) + root, err := share.NewRoot(randEDS) + require.NoError(b, err) + inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} + sample, err := inMem.SampleForProofAxis(0, 0, rsmt2d.Row) + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sample.Validate(root, 0, 0) + } +} diff --git a/share/shwap/share.go b/share/shwap/share.go new file mode 100644 index 0000000000..a7f7ef67b7 --- /dev/null +++ b/share/shwap/share.go @@ -0,0 +1,49 @@ +package shwap + +import ( + "fmt" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap/pb" +) + +// ShareFromProto converts a protobuf Share object to the application's internal share +// representation. It returns nil if the input protobuf Share is nil, ensuring safe handling of nil +// values. +func ShareFromProto(s *pb.Share) share.Share { + if s == nil { + return nil + } + return s.Data +} + +// SharesToProto converts a slice of Shares from the application's internal representation to a +// slice of protobuf Share objects. This function allocates memory for the protobuf objects and +// copies data from the input slice. +func SharesToProto(shrs []share.Share) []*pb.Share { + protoShares := make([]*pb.Share, len(shrs)) + for i, shr := range shrs { + protoShares[i] = &pb.Share{Data: shr} + } + return protoShares +} + +// SharesFromProto converts a slice of protobuf Share objects to the application's internal slice +// of Shares. It ensures that each Share is correctly transformed using the ShareFromProto function. +func SharesFromProto(shrs []*pb.Share) []share.Share { + shares := make([]share.Share, len(shrs)) + for i, shr := range shrs { + shares[i] = ShareFromProto(shr) + } + return shares +} + +// ValidateShares takes the slice of shares and checks their conformance to share format. 
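+// It returns the first error encountered, annotated with the index of the offending share.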
+func ValidateShares(shares []share.Share) error { + for i, shr := range shares { + if err := share.ValidateShare(shr); err != nil { + return fmt.Errorf("while validating share at index %d: %w", i, err) + } + } + return nil +} diff --git a/store/file/codec.go b/store/file/codec.go new file mode 100644 index 0000000000..a27280be11 --- /dev/null +++ b/store/file/codec.go @@ -0,0 +1,38 @@ +package file + +import ( + "sync" + + "github.com/klauspost/reedsolomon" +) + +var codec Codec + +func init() { + codec = NewCodec() +} + +type Codec interface { + Encoder(len int) (reedsolomon.Encoder, error) +} + +type codecCache struct { + cache sync.Map +} + +func NewCodec() Codec { + return &codecCache{} +} + +func (l *codecCache) Encoder(len int) (reedsolomon.Encoder, error) { + enc, ok := l.cache.Load(len) + if !ok { + var err error + enc, err = reedsolomon.New(len/2, len/2, reedsolomon.WithLeopardGF(true)) + if err != nil { + return nil, err + } + l.cache.Store(len, enc) + } + return enc.(reedsolomon.Encoder), nil +} diff --git a/store/file/codec_test.go b/store/file/codec_test.go new file mode 100644 index 0000000000..d6fdbb3045 --- /dev/null +++ b/store/file/codec_test.go @@ -0,0 +1,83 @@ +package file + +import ( + "fmt" + "testing" + + "github.com/klauspost/reedsolomon" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func BenchmarkCodec(b *testing.B) { + minSize, maxSize := 32, 128 + + for size := minSize; size <= maxSize; size *= 2 { + // BenchmarkCodec/Leopard/size:32-10 409194 2793 ns/op + // BenchmarkCodec/Leopard/size:64-10 190969 6170 ns/op + // BenchmarkCodec/Leopard/size:128-10 82821 14287 ns/op + b.Run(fmt.Sprintf("Leopard/size:%v", size), func(b *testing.B) { + enc, err := reedsolomon.New(size/2, size/2, reedsolomon.WithLeopardGF(true)) + require.NoError(b, err) + + shards := newShards(b, size, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = enc.Encode(shards) + require.NoError(b, err) + } + }) + + // BenchmarkCodec/default/size:32-10 222153 5364 ns/op + // BenchmarkCodec/default/size:64-10 58831 20349 ns/op + // BenchmarkCodec/default/size:128-10 14940 80471 ns/op + b.Run(fmt.Sprintf("default/size:%v", size), func(b *testing.B) { + enc, err := reedsolomon.New(size/2, size/2, reedsolomon.WithLeopardGF(false)) + require.NoError(b, err) + + shards := newShards(b, size, true) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = enc.Encode(shards) + require.NoError(b, err) + } + }) + + // BenchmarkCodec/default-reconstructSome/size:32-10 1263585 954.4 ns/op + // BenchmarkCodec/default-reconstructSome/size:64-10 762273 1554 ns/op + // BenchmarkCodec/default-reconstructSome/size:128-10 429268 2974 ns/op + b.Run(fmt.Sprintf("default-reconstructSome/size:%v", size), func(b *testing.B) { + enc, err := reedsolomon.New(size/2, size/2, reedsolomon.WithLeopardGF(false)) + require.NoError(b, err) + + shards := newShards(b, size, false) + targets := make([]bool, size) + target := size - 2 + targets[target] = true + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = enc.ReconstructSome(shards, targets) + require.NoError(b, err) + shards[target] = nil + } + }) + } +} + +func newShards(b require.TestingT, size int, fillParity bool) [][]byte { + shards := make([][]byte, size) + original := sharetest.RandShares(b, size/2) + copy(shards, original) + + if fillParity { + // fill with parity empty Shares + for j := len(original); j < len(shards); j++ { + shards[j] = make([]byte, len(original[0])) + } + } + return shards +} diff --git 
a/store/file/header.go b/store/file/header.go new file mode 100644 index 0000000000..846aa17bac --- /dev/null +++ b/store/file/header.go @@ -0,0 +1,100 @@ +package file + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/celestiaorg/celestia-node/share" +) + +type headerVersion uint8 + +const ( + headerVersionV0 headerVersion = 1 + headerVOSize int = 40 +) + +type headerV0 struct { + fileVersion fileVersion + fileType fileType + + // Taken directly from EDS + shareSize uint16 + squareSize uint16 + + datahash share.DataHash +} + +type fileVersion uint8 + +const ( + fileV0 fileVersion = iota + 1 +) + +type fileType uint8 + +const ( + ODS fileType = iota + q1q4 +) + +func readHeader(r io.Reader) (*headerV0, error) { + // read first byte to determine the fileVersion + var version headerVersion + if err := binary.Read(r, binary.LittleEndian, &version); err != nil { + return nil, fmt.Errorf("readHeader: %w", err) + } + + switch version { + case headerVersionV0: + h := &headerV0{} + _, err := h.ReadFrom(r) + return h, err + default: + return nil, fmt.Errorf("unsupported header fileVersion: %d", version) + } +} + +func writeHeader(w io.Writer, h *headerV0) error { + n, err := w.Write([]byte{byte(headerVersionV0)}) + if err != nil { + return fmt.Errorf("writeHeader: %w", err) + } + if n != 1 { + return fmt.Errorf("writeHeader: wrote %d bytes, expected 1", n) + } + _, err = h.WriteTo(w) + return err +} + +func (h *headerV0) Size() int { + // header size + 1 byte for header fileVersion + return headerVOSize + 1 +} + +func (h *headerV0) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, headerVOSize) + buf[0] = byte(h.fileVersion) + buf[1] = byte(h.fileType) + binary.LittleEndian.PutUint16(buf[4:6], h.shareSize) + binary.LittleEndian.PutUint16(buf[6:8], h.squareSize) + copy(buf[8:40], h.datahash) + n, err := w.Write(buf) + return int64(n), err +} + +func (h *headerV0) ReadFrom(r io.Reader) (int64, error) { + bytesHeader := make([]byte, headerVOSize) + n, err := io.ReadFull(r, bytesHeader) + if n != headerVOSize { + return 0, fmt.Errorf("headerV0 ReadFrom: read %d bytes, expected %d", len(bytesHeader), headerVOSize) + } + + h.fileVersion = fileVersion(bytesHeader[0]) + h.fileType = fileType(bytesHeader[1]) + h.shareSize = binary.LittleEndian.Uint16(bytesHeader[4:6]) + h.squareSize = binary.LittleEndian.Uint16(bytesHeader[6:8]) + h.datahash = bytesHeader[8:40] + return int64(headerVOSize), err +} diff --git a/store/file/mempool.go b/store/file/mempool.go new file mode 100644 index 0000000000..f2eab0249f --- /dev/null +++ b/store/file/mempool.go @@ -0,0 +1,85 @@ +package file + +import ( + "runtime" + "sync" + + "github.com/celestiaorg/celestia-node/share" +) + +// TODO: need better name +var memPools poolsMap + +func init() { + memPools = make(map[int]*memPool) +} + +type poolsMap map[int]*memPool + +type memPool struct { + ods *sync.Pool + halfAxis *sync.Pool +} + +func (m poolsMap) get(size int) *memPool { + pool, ok := m[size] + if !ok { + pool = &memPool{ + ods: newOdsPool(size), + halfAxis: newHalfAxisPool(size), + } + m[size] = pool + } + return pool +} + +func (m *memPool) putSquare(s *[][]share.Share) { + m.ods.Put(s) +} + +func (m *memPool) square() [][]share.Share { + square := m.ods.Get().(*[][]share.Share) + runtime.SetFinalizer(square, m.putSquare) + return *square +} + +func (m *memPool) putHalfAxis(buf *[]share.Share) { + m.halfAxis.Put(buf) +} + +func (m *memPool) getHalfAxis() []share.Share { + half := m.halfAxis.Get().(*[]share.Share) + runtime.SetFinalizer(half, 
m.putHalfAxis) + return *half +} + +func newOdsPool(size int) *sync.Pool { + return &sync.Pool{ + New: func() interface{} { + rows := make([][]share.Share, size) + for i := range rows { + if rows[i] == nil { + rows[i] = newHalfAxis(size) + } + } + return &rows + }, + } +} + +func newHalfAxisPool(size int) *sync.Pool { + return &sync.Pool{ + New: func() interface{} { + half := newHalfAxis(size) + return &half + }, + } +} + +func newHalfAxis(size int) []share.Share { + shares := make([]share.Share, size) + for i := range shares { + shares[i] = make([]byte, share.Size) + } + return shares +} diff --git a/store/file/ods.go b/store/file/ods.go new file mode 100644 index 0000000000..2d695b5839 --- /dev/null +++ b/store/file/ods.go @@ -0,0 +1,280 @@ +package file + +import ( + "context" + "fmt" + "io" + "os" + "sync" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ eds.AccessorCloser = (*ODSFile)(nil) + +type ODSFile struct { + path string + hdr *headerV0 + fl *os.File + + lock sync.RWMutex + // ods stores an in-memory cache of the original data square to enhance read performance. This + // cache is particularly beneficial for operations that require reading the entire square, such as: + // - Serving samples from the fourth quadrant of the square, which necessitates reconstructing data + // from all rows. - Streaming the entire ODS by Reader(), ensuring efficient data delivery without + // repeated file reads. - Serving full ODS data by Shares(). + // Storing the square in memory allows for efficient single-read operations, avoiding the need for + // piecemeal reads by rows or columns, and facilitates quick access to data for these operations. + ods square +} + +// OpenODSFile opens an existing file. File has to be closed after usage. +func OpenODSFile(path string) (*ODSFile, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + h, err := readHeader(f) + if err != nil { + return nil, err + } + + return &ODSFile{ + path: path, + hdr: h, + fl: f, + }, nil +} + +// CreateODSFile creates a new file. File has to be closed after usage. +func CreateODSFile( + path string, + datahash share.DataHash, + eds *rsmt2d.ExtendedDataSquare, +) (*ODSFile, error) { + f, err := os.Create(path) + if err != nil { + return nil, fmt.Errorf("file create: %w", err) + } + + h := &headerV0{ + fileVersion: fileV0, + fileType: ODS, + shareSize: share.Size, + squareSize: uint16(eds.Width()), + datahash: datahash, + } + + err = writeODSFile(f, h, eds) + if err != nil { + return nil, fmt.Errorf("writing ODS file: %w", err) + } + + err = f.Sync() + if err != nil { + return nil, fmt.Errorf("syncing file: %w", err) + } + + return &ODSFile{ + path: path, + fl: f, + hdr: h, + }, nil +} + +func writeODSFile(w io.Writer, h *headerV0, eds *rsmt2d.ExtendedDataSquare) error { + err := writeHeader(w, h) + if err != nil { + return err + } + + for _, shr := range eds.FlattenedODS() { + if _, err := w.Write(shr); err != nil { + return err + } + } + return nil +} + +// Size returns square size of the Accessor. +func (f *ODSFile) Size(context.Context) int { + return f.size() +} + +func (f *ODSFile) size() int { + return int(f.hdr.squareSize) +} + +// Close closes the file. +func (f *ODSFile) Close() error { + return f.fl.Close() +} + +// Sample returns share and corresponding proof for row and column indices. 
Implementation can +// choose which axis to use for proof. Chosen axis for proof should be indicated in the returned +// Sample. +func (f *ODSFile) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { + // Sample proof axis is selected to optimize read performance. + // - For the first and second quadrants, we read the row axis because it is more efficient to read + // single row than reading full ODS to calculate single column + // - For the third quadrant, we read the column axis because it is more efficient to read single + // column than reading full ODS to calculate single row + // - For the fourth quadrant, it does not matter which axis we read because we need to read full ODS + // to calculate the sample + axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx + if colIdx < f.size()/2 && rowIdx >= f.size()/2 { + axisType, axisIdx, shrIdx = rsmt2d.Col, colIdx, rowIdx + } + + axis, err := f.axis(ctx, axisType, axisIdx) + if err != nil { + return shwap.Sample{}, fmt.Errorf("reading axis: %w", err) + } + + return shwap.SampleFromShares(axis, axisType, axisIdx, shrIdx) +} + +// AxisHalf returns half of shares axis of the given type and index. Side is determined by +// implementation. Implementations should indicate the side in the returned AxisHalf. +func (f *ODSFile) AxisHalf(_ context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) { + // Read the axis from the file if the axis is a row and from the top half of the square, or if the + // axis is a column and from the left half of the square. + if axisIdx < f.size()/2 { + half, err := f.readAxisHalf(axisType, axisIdx) + if err != nil { + return eds.AxisHalf{}, fmt.Errorf("reading axis half: %w", err) + } + return half, nil + } + + // if axis is from the second half of the square, read full ODS and compute the axis half + err := f.readODS() + if err != nil { + return eds.AxisHalf{}, err + } + + half, err := f.ods.computeAxisHalf(axisType, axisIdx) + if err != nil { + return eds.AxisHalf{}, fmt.Errorf("computing axis half: %w", err) + } + return half, nil +} + +// RowNamespaceData returns data for the given namespace and row index. +func (f *ODSFile) RowNamespaceData( + ctx context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + shares, err := f.axis(ctx, rsmt2d.Row, rowIdx) + if err != nil { + return shwap.RowNamespaceData{}, err + } + return shwap.RowNamespaceDataFromShares(shares, namespace, rowIdx) +} + +// Shares returns data shares extracted from the Accessor. 
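+// It loads the entire ODS into the in-memory cache (if not already cached) before returning
+// the shares.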
+func (f *ODSFile) Shares(context.Context) ([]share.Share, error) { + err := f.readODS() + if err != nil { + return nil, err + } + return f.ods.shares() +} + +func (f *ODSFile) readAxisHalf(axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) { + f.lock.RLock() + ODS := f.ods + f.lock.RUnlock() + if ODS != nil { + return f.ods.axisHalf(axisType, axisIdx) + } + + switch axisType { + case rsmt2d.Col: + col, err := f.readCol(axisIdx, 0) + return eds.AxisHalf{ + Shares: col, + IsParity: false, + }, err + case rsmt2d.Row: + row, err := f.readRow(axisIdx) + return eds.AxisHalf{ + Shares: row, + IsParity: false, + }, err + } + return eds.AxisHalf{}, fmt.Errorf("unknown axis") +} + +func (f *ODSFile) readODS() error { + f.lock.Lock() + defer f.lock.Unlock() + if f.ods != nil { + return nil + } + + // reset file pointer to the beginning of the file shares data + _, err := f.fl.Seek(int64(f.hdr.Size()), io.SeekStart) + if err != nil { + return fmt.Errorf("discarding header: %w", err) + } + + square, err := readSquare(f.fl, share.Size, f.size()) + if err != nil { + return fmt.Errorf("reading ODS: %w", err) + } + f.ods = square + return nil +} + +func (f *ODSFile) readRow(idx int) ([]share.Share, error) { + shrLn := int(f.hdr.shareSize) + odsLn := f.size() / 2 + + shares := make([]share.Share, odsLn) + + pos := idx * odsLn + offset := f.hdr.Size() + pos*shrLn + + axsData := make([]byte, odsLn*shrLn) + if _, err := f.fl.ReadAt(axsData, int64(offset)); err != nil { + return nil, err + } + + for i := range shares { + shares[i] = axsData[i*shrLn : (i+1)*shrLn] + } + return shares, nil +} + +func (f *ODSFile) readCol(axisIdx, quadrantIdx int) ([]share.Share, error) { + shrLn := int(f.hdr.shareSize) + odsLn := f.size() / 2 + quadrantOffset := quadrantIdx * odsLn * odsLn * shrLn + + shares := memPools.get(odsLn).getHalfAxis() + for i := range shares { + pos := axisIdx + i*odsLn + offset := f.hdr.Size() + quadrantOffset + pos*shrLn + + if _, err := f.fl.ReadAt(shares[i], int64(offset)); err != nil { + return nil, err + } + } + return shares, nil +} + +func (f *ODSFile) axis(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) ([]share.Share, error) { + half, err := f.AxisHalf(ctx, axisType, axisIdx) + if err != nil { + return nil, err + } + + return half.Extended() +} diff --git a/store/file/ods_test.go b/store/file/ods_test.go new file mode 100644 index 0000000000..8f91974446 --- /dev/null +++ b/store/file/ods_test.go @@ -0,0 +1,122 @@ +package file + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/rand" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds/edstest" + eds "github.com/celestiaorg/celestia-node/share/new_eds" +) + +func TestCreateODSFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + edsIn := edstest.RandEDS(t, 8) + datahash := share.DataHash(rand.Bytes(32)) + path := t.TempDir() + "/" + datahash.String() + f, err := CreateODSFile(path, datahash, edsIn) + require.NoError(t, err) + + shares, err := f.Shares(ctx) + require.NoError(t, err) + expected := edsIn.FlattenedODS() + require.Equal(t, expected, shares) + require.Equal(t, datahash, f.hdr.datahash) + require.NoError(t, f.Close()) + + f, err = OpenODSFile(path) + require.NoError(t, err) + shares, err = f.Shares(ctx) + require.NoError(t, err) + require.Equal(t, expected, shares) + require.Equal(t, 
datahash, f.hdr.datahash) + require.NoError(t, f.Close()) +} + +func TestReadODSFromFile(t *testing.T) { + eds := edstest.RandEDS(t, 8) + path := t.TempDir() + "/testfile" + f, err := CreateODSFile(path, []byte{}, eds) + require.NoError(t, err) + + err = f.readODS() + require.NoError(t, err) + for i, row := range f.ods { + original := eds.Row(uint(i))[:eds.Width()/2] + require.True(t, len(original) == len(row)) + require.Equal(t, original, row) + } +} + +func TestODSFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + ODSSize := 8 + eds.TestSuiteAccessor(ctx, t, createODSFile, ODSSize) +} + +// ReconstructSome, default codec +// BenchmarkAxisFromODSFile/Size:32/Axis:row/squareHalf:first(original)-10 455848 2588 ns/op +// BenchmarkAxisFromODSFile/Size:32/Axis:row/squareHalf:second(extended)-10 9015 203950 ns/op +// BenchmarkAxisFromODSFile/Size:32/Axis:col/squareHalf:first(original)-10 52734 21178 ns/op +// BenchmarkAxisFromODSFile/Size:32/Axis:col/squareHalf:second(extended)-10 8830 127452 ns/op +// BenchmarkAxisFromODSFile/Size:64/Axis:row/squareHalf:first(original)-10 303834 4763 ns/op +// BenchmarkAxisFromODSFile/Size:64/Axis:row/squareHalf:second(extended)-10 2940 426246 ns/op +// BenchmarkAxisFromODSFile/Size:64/Axis:col/squareHalf:first(original)-10 27758 42842 ns/op +// BenchmarkAxisFromODSFile/Size:64/Axis:col/squareHalf:second(extended)-10 3385 353868 ns/op +// BenchmarkAxisFromODSFile/Size:128/Axis:row/squareHalf:first(original)-10 172086 6455 ns/op +// BenchmarkAxisFromODSFile/Size:128/Axis:row/squareHalf:second(extended)-10 672 1550386 ns/op +// BenchmarkAxisFromODSFile/Size:128/Axis:col/squareHalf:first(original)-10 14202 84316 ns/op +// BenchmarkAxisFromODSFile/Size:128/Axis:col/squareHalf:second(extended)-10 978 1230980 ns/op +func BenchmarkAxisFromODSFile(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + b.Cleanup(cancel) + + minSize, maxSize := 32, 32 + newFile := func(size int) eds.Accessor { + eds := edstest.RandEDS(b, size) + return createODSFile(b, eds) + } + eds.BenchGetHalfAxisFromAccessor(ctx, b, newFile, minSize, maxSize) +} + +// BenchmarkShareFromODSFile/Size:32/Axis:row/squareHalf:first(original)-10 10339 111328 ns/op +// BenchmarkShareFromODSFile/Size:32/Axis:row/squareHalf:second(extended)-10 3392 359180 ns/op +// BenchmarkShareFromODSFile/Size:32/Axis:col/squareHalf:first(original)-10 8925 131352 ns/op +// BenchmarkShareFromODSFile/Size:32/Axis:col/squareHalf:second(extended)-10 3447 346218 ns/op +// BenchmarkShareFromODSFile/Size:64/Axis:row/squareHalf:first(original)-10 5503 215833 ns/op +// BenchmarkShareFromODSFile/Size:64/Axis:row/squareHalf:second(extended)-10 1231 1001053 ns/op +// BenchmarkShareFromODSFile/Size:64/Axis:col/squareHalf:first(original)-10 4711 250001 ns/op +// BenchmarkShareFromODSFile/Size:64/Axis:col/squareHalf:second(extended)-10 1315 910079 ns/op +// BenchmarkShareFromODSFile/Size:128/Axis:row/squareHalf:first(original)-10 2364 435748 ns/op +// BenchmarkShareFromODSFile/Size:128/Axis:row/squareHalf:second(extended)-10 358 3330620 ns/op +// BenchmarkShareFromODSFile/Size:128/Axis:col/squareHalf:first(original)-10 2114 514642 ns/op +// BenchmarkShareFromODSFile/Size:128/Axis:col/squareHalf:second(extended)-10 373 3068104 ns/op +func BenchmarkShareFromODSFile(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + b.Cleanup(cancel) + + minSize, maxSize := 32, 32 + newFile := func(size int) 
eds.Accessor { + eds := edstest.RandEDS(b, size) + return createODSFile(b, eds) + } + eds.BenchGetSampleFromAccessor(ctx, b, newFile, minSize, maxSize) +} + +func createODSFile(t testing.TB, eds *rsmt2d.ExtendedDataSquare) eds.Accessor { + path := t.TempDir() + "/" + strconv.Itoa(rand.Intn(1000)) + fl, err := CreateODSFile(path, []byte{}, eds) + require.NoError(t, err) + return fl +} diff --git a/store/file/square.go b/store/file/square.go new file mode 100644 index 0000000000..325abd1443 --- /dev/null +++ b/store/file/square.go @@ -0,0 +1,138 @@ +package file + +import ( + "bufio" + "fmt" + "io" + + "golang.org/x/sync/errgroup" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + eds "github.com/celestiaorg/celestia-node/share/new_eds" +) + +type square [][]share.Share + +// readSquare reads Shares from the reader and returns a square. It assumes that the reader is +// positioned at the beginning of the Shares. It knows the size of the Shares and the size of the +// square, so reads from reader are limited to exactly the amount of data required. +func readSquare(r io.Reader, shareSize, edsSize int) (square, error) { + odsLn := edsSize / 2 + + // get pre-allocated square and buffer from memPools + square := memPools.get(odsLn).square() + + br := bufio.NewReaderSize(r, 4096) + var total int + for i := 0; i < odsLn; i++ { + for j := 0; j < odsLn; j++ { + n, err := io.ReadFull(br, square[i][j]) + if err != nil { + return nil, fmt.Errorf("reading share: %w, bytes read: %v", err, total+n) + } + if n != shareSize { + return nil, fmt.Errorf("share size mismatch: expected %v, got %v", shareSize, n) + } + total += n + } + } + return square, nil +} + +func (s square) size() int { + return len(s) +} + +func (s square) shares() ([]share.Share, error) { + shares := make([]share.Share, 0, s.size()*s.size()) + for _, row := range s { + shares = append(shares, row...) 
+	}
+	return shares, nil
+}
+
+func (s square) axisHalf(axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
+	if s == nil {
+		return eds.AxisHalf{}, fmt.Errorf("square is nil")
+	}
+
+	if axisIdx >= s.size() {
+		return eds.AxisHalf{}, fmt.Errorf("index is out of square bounds")
+	}
+
+	// the square stores rows directly in the top-level slice, so a row half can be returned by index
+	if axisType == rsmt2d.Row {
+		row := s[axisIdx]
+		return eds.AxisHalf{
+			Shares:   row,
+			IsParity: false,
+		}, nil
+	}
+
+	// construct the half column from the row-ordered square
+	col := make([]share.Share, s.size())
+	for i := 0; i < s.size(); i++ {
+		col[i] = s[i][axisIdx]
+	}
+	return eds.AxisHalf{
+		Shares:   col,
+		IsParity: false,
+	}, nil
+}
+
+func (s square) computeAxisHalf(
+	axisType rsmt2d.Axis,
+	axisIdx int,
+) (eds.AxisHalf, error) {
+	shares := make([]share.Share, s.size())
+
+	// extend the opposite half of the square while collecting shares for the first half of the required axis
+	g := errgroup.Group{}
+	opposite := oppositeAxis(axisType)
+	for i := 0; i < s.size(); i++ {
+		// shadow the loop variable, to ensure the correct value is captured by the goroutine
+		i := i
+		g.Go(func() error {
+			half, err := s.axisHalf(opposite, i)
+			if err != nil {
+				return err
+			}
+
+			enc, err := codec.Encoder(s.size() * 2)
+			if err != nil {
+				return fmt.Errorf("getting encoder: %w", err)
+			}
+
+			shards := make([][]byte, s.size()*2)
+			if half.IsParity {
+				copy(shards[s.size():], half.Shares)
+			} else {
+				copy(shards, half.Shares)
+			}
+
+			target := make([]bool, s.size()*2)
+			target[axisIdx] = true
+
+			err = enc.ReconstructSome(shards, target)
+			if err != nil {
+				return fmt.Errorf("reconstruct some: %w", err)
+			}
+
+			shares[i] = shards[axisIdx]
+			return nil
+		})
+	}
+
+	err := g.Wait()
+	return eds.AxisHalf{
+		Shares:   shares,
+		IsParity: false,
+	}, err
+}
+
+func oppositeAxis(axis rsmt2d.Axis) rsmt2d.Axis {
+	if axis == rsmt2d.Col {
+		return rsmt2d.Row
+	}
+	return rsmt2d.Col
+}
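
For reviewers, a minimal sketch of how the new ODSFile is meant to be used end to end, based on the CreateODSFile, OpenODSFile, Size, and Sample signatures introduced above. The helper name, path handling, and error handling here are illustrative only and are not part of the patch.

```go
package odsfileexample

import (
	"context"

	"github.com/celestiaorg/rsmt2d"

	"github.com/celestiaorg/celestia-node/share"
	"github.com/celestiaorg/celestia-node/store/file"
)

// WriteAndSample persists the ODS half of an extended data square, reopens the file, and serves a
// single sample from it. Fourth-quadrant coordinates force the file to load its in-memory ODS
// cache and reconstruct the requested axis.
func WriteAndSample(ctx context.Context, path string, datahash share.DataHash, square *rsmt2d.ExtendedDataSquare) error {
	f, err := file.CreateODSFile(path, datahash, square)
	if err != nil {
		return err
	}
	if err := f.Close(); err != nil {
		return err
	}

	f, err = file.OpenODSFile(path)
	if err != nil {
		return err
	}
	defer f.Close()

	size := f.Size(ctx) // extended square width recorded in the header
	if _, err := f.Sample(ctx, size-1, size-1); err != nil {
		return err
	}
	return nil
}
```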
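The readRow and readCol offsets follow from the row-major layout that writeODSFile produces: the original square's shares are written contiguously right after the header. A small sketch of that arithmetic, with made-up header and share sizes (the real values come from headerV0, defined elsewhere in this PR):

```go
package main

import "fmt"

// shareOffset computes the byte offset of the share at (rowIdx, colIdx) in an ODS file:
// offset = headerSize + (rowIdx*odsWidth + colIdx)*shareSize. This mirrors the arithmetic
// in readRow and readCol; headerSize stands in for headerV0.Size().
func shareOffset(headerSize, shareSize, odsWidth, rowIdx, colIdx int) int {
	return headerSize + (rowIdx*odsWidth+colIdx)*shareSize
}

func main() {
	// Illustrative numbers only: a 4x4 ODS with 512-byte shares and a 32-byte header.
	const headerSize, shareSize, odsWidth = 32, 512, 4

	// readRow(1) performs a single contiguous read of odsWidth shares starting here.
	fmt.Println("row 1 starts at byte", shareOffset(headerSize, shareSize, odsWidth, 1, 0))

	// readCol(2, 0) performs one small read per row, at these offsets.
	for row := 0; row < odsWidth; row++ {
		fmt.Println("col 2, row", row, "at byte", shareOffset(headerSize, shareSize, odsWidth, row, 2))
	}
}
```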
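computeAxisHalf leans on the codec's partial-reconstruction primitive to rebuild only the requested share per axis. Below is a self-contained sketch of that primitive using github.com/klauspost/reedsolomon (the dependency this diff promotes to a direct requirement) rather than the node's codec wrapper; the shard count, shard size, and values are made up for illustration.

```go
package main

import (
	"fmt"

	"github.com/klauspost/reedsolomon"
)

func main() {
	// 2 original shards extended by 2 parity shards, mirroring the 2x row/column extension.
	enc, err := reedsolomon.New(2, 2)
	if err != nil {
		panic(err)
	}

	shards := [][]byte{
		{1, 2, 3, 4},    // data shard 0
		{5, 6, 7, 8},    // data shard 1
		make([]byte, 4), // parity shard 0, filled by Encode
		make([]byte, 4), // parity shard 1, filled by Encode
	}
	if err := enc.Encode(shards); err != nil {
		panic(err)
	}

	// Drop a shard and ask the codec to rebuild only that one, using a full-length mask
	// analogous to the target slice in computeAxisHalf.
	shards[1] = nil
	required := make([]bool, len(shards))
	required[1] = true
	if err := enc.ReconstructSome(shards, required); err != nil {
		panic(err)
	}
	fmt.Println(shards[1]) // [5 6 7 8]
}
```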