From e804d2c3529436e73e4839ffe567d5ca4e05a3b4 Mon Sep 17 00:00:00 2001 From: Wondertan Date: Thu, 21 Sep 2023 13:06:27 +0200 Subject: [PATCH] add support for col proofs sampling --- share/eds/file.go | 58 ++++++++++++++++++++---------- share/eds/file_test.go | 72 ++++++++++++++++++++++++------------- share/ipldv2/ipldv2_test.go | 29 ++++++++------- share/ipldv2/sample.go | 17 +++++++-- 4 files changed, 117 insertions(+), 59 deletions(-) diff --git a/share/eds/file.go b/share/eds/file.go index 379f675251..64f79dfd6b 100644 --- a/share/eds/file.go +++ b/share/eds/file.go @@ -86,23 +86,42 @@ func (f *File) Header() Header { return f.hdr } -func (f *File) Axis(idx int, _ rsmt2d.Axis) ([]share.Share, error) { - // TODO: Add Col support - shrLn := int64(f.hdr.ShareSize) - sqrLn := int64(f.hdr.SquareSize) - rwwLn := shrLn * sqrLn +func (f *File) Axis(idx int, axis rsmt2d.Axis) ([]share.Share, error) { + shrLn := int(f.hdr.ShareSize) + sqrLn := int(f.hdr.SquareSize) - offset := int64(idx)*rwwLn + HeaderSize - rowdata := make([]byte, rwwLn) - if _, err := f.fl.ReadAt(rowdata, offset); err != nil { - return nil, err - } + shrs := make([]share.Share, sqrLn) + switch axis { + case rsmt2d.Col: + // [] [] [] [] + // [] [] [] [] + // [] [] [] [] + // [] [] [] [] + + for i := 0; i < sqrLn; i++ { + pos := idx+i*sqrLn + offset := pos*shrLn+HeaderSize + + shr := make(share.Share, shrLn) + if _, err := f.fl.ReadAt(shr, int64(offset)); err != nil { + return nil, err + } + shrs[i] = shr + } + case rsmt2d.Row: + pos := idx*sqrLn + offset := pos*shrLn + HeaderSize + axsData := make([]byte, sqrLn*shrLn) + if _, err := f.fl.ReadAt(axsData, int64(offset)); err != nil { + return nil, err + } - row := make([]share.Share, sqrLn) - for i := range row { - row[i] = rowdata[int64(i)*shrLn : (int64(i)+1)*shrLn] + for i := range shrs { + shrs[i] = axsData[i*shrLn : (i+1)*shrLn] + } } - return row, nil + + return shrs, nil } func (f *File) Share(idx int) (share.Share, error) { @@ -120,13 +139,17 @@ func (f *File) Share(idx int) (share.Share, error) { func (f *File) ShareWithProof(idx int, axis rsmt2d.Axis) (share.Share, nmt.Proof, error) { // TODO: Cache the axis as well as computed tree sqrLn := int(f.hdr.SquareSize) - rowIdx := idx / sqrLn - shrs, err := f.Axis(rowIdx, axis) + axsIdx, shrIdx := idx/sqrLn, idx%sqrLn + if axis == rsmt2d.Col { + axsIdx, shrIdx = shrIdx, axsIdx + } + + shrs, err := f.Axis(axsIdx, axis) if err != nil { return nil, nmt.Proof{}, err } - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(rowIdx)) + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(axsIdx)) for _, shr := range shrs { err = tree.Push(shr) if err != nil { @@ -134,7 +157,6 @@ func (f *File) ShareWithProof(idx int, axis rsmt2d.Axis) (share.Share, nmt.Proof } } - shrIdx := idx % sqrLn proof, err := tree.ProveRange(shrIdx, shrIdx+1) if err != nil { return nil, nmt.Proof{}, err diff --git a/share/eds/file_test.go b/share/eds/file_test.go index e9e3ea8647..9d52f1b05a 100644 --- a/share/eds/file_test.go +++ b/share/eds/file_test.go @@ -4,18 +4,20 @@ import ( "crypto/sha256" "testing" + "github.com/celestiaorg/celestia-node/share" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/celestiaorg/rsmt2d" - "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds/edstest" ) func TestFile(t *testing.T) { path := t.TempDir() + "/testfile" - eds := edstest.RandEDS(t, 16) + eds := edstest.RandEDS(t, 8) + root, err := share.NewRoot(eds) + require.NoError(t, err) fl, err := CreateFile(path, eds) require.NoError(t, err) @@ -25,33 +27,40 @@ func TestFile(t *testing.T) { fl, err = OpenFile(path) require.NoError(t, err) - for i := 0; i < int(eds.Width()); i++ { - row, err := fl.Axis(i, rsmt2d.Row) - require.NoError(t, err) - assert.EqualValues(t, eds.Row(uint(i)), row) + axis := []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} + for _, axis := range axis { + for i := 0; i < int(eds.Width()); i++ { + row, err := fl.Axis(i, axis) + require.NoError(t, err) + assert.EqualValues(t, getAxis(i, axis, eds), row) + } } width := int(eds.Width()) - for i := 0; i < width*width; i++ { - row, col := uint(i/width), uint(i%width) - shr, err := fl.Share(i) - require.NoError(t, err) - assert.EqualValues(t, eds.GetCell(row, col), shr) - - shr, proof, err := fl.ShareWithProof(i, rsmt2d.Row) - require.NoError(t, err) - assert.EqualValues(t, eds.GetCell(row, col), shr) - - roots, err := eds.RowRoots() - require.NoError(t, err) - - namespace := share.ParitySharesNamespace - if int(row) < width/2 && int(col) < width/2 { - namespace = share.GetNamespace(shr) - } + for _, axis := range axis { + for i := 0; i < width*width; i++ { + row, col := uint(i/width), uint(i%width) + shr, err := fl.Share(i) + require.NoError(t, err) + assert.EqualValues(t, eds.GetCell(row, col), shr) + + shr, proof, err := fl.ShareWithProof(i, axis) + require.NoError(t, err) + assert.EqualValues(t, eds.GetCell(row, col), shr) + + namespace := share.ParitySharesNamespace + if int(row) < width/2 && int(col) < width/2 { + namespace = share.GetNamespace(shr) + } - ok := proof.VerifyInclusion(sha256.New(), namespace.ToNMT(), [][]byte{shr}, roots[row]) - assert.True(t, ok) + dahroot := root.RowRoots[row] + if axis == rsmt2d.Col { + dahroot = root.ColumnRoots[col] + } + + ok := proof.VerifyInclusion(sha256.New(), namespace.ToNMT(), [][]byte{shr}, dahroot) + assert.True(t, ok) + } } out, err := fl.EDS() @@ -61,3 +70,16 @@ func TestFile(t *testing.T) { err = fl.Close() require.NoError(t, err) } + + +// TODO(@Wondertan): Should be a method on eds +func getAxis(idx int, axis rsmt2d.Axis, eds *rsmt2d.ExtendedDataSquare) [][]byte { + switch axis { + case rsmt2d.Row: + return eds.Row(uint(idx)) + case rsmt2d.Col: + return eds.Col(uint(idx)) + default: + panic("") + } +} diff --git a/share/ipldv2/ipldv2_test.go b/share/ipldv2/ipldv2_test.go index 08e8e6e1d8..121f3e079d 100644 --- a/share/ipldv2/ipldv2_test.go +++ b/share/ipldv2/ipldv2_test.go @@ -26,24 +26,27 @@ func TestV2Roundtrip(t *testing.T) { dn.ConnectAll() square := edstest.RandEDS(t, 16) + axis := []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row} width := int(square.Width()) - for i := 0; i < width*width; i++ { - smpl, err := NewSampleFrom(square, i, rsmt2d.Row) - require.NoError(t, err) + for _, axis := range axis { + for i := 0; i < width*width; i++ { + smpl, err := NewSampleFrom(square, i, axis) + require.NoError(t, err) - err = smpl.Validate() - require.NoError(t, err) + err = smpl.Validate() + require.NoError(t, err) - blkIn, err := smpl.IPLDBlock() - require.NoError(t, err) + blkIn, err := smpl.IPLDBlock() + require.NoError(t, err) - err = srv1.AddBlock(ctx, blkIn) - require.NoError(t, err) + err = srv1.AddBlock(ctx, blkIn) + require.NoError(t, err) - blkOut, err := srv2.GetBlock(ctx, blkIn.Cid()) - require.NoError(t, err) + blkOut, err := srv2.GetBlock(ctx, blkIn.Cid()) + require.NoError(t, err) - assert.EqualValues(t, blkIn.RawData(), blkOut.RawData()) - assert.EqualValues(t, blkIn.Cid(), blkOut.Cid()) + assert.EqualValues(t, blkIn.RawData(), blkOut.RawData()) + assert.EqualValues(t, blkIn.Cid(), blkOut.Cid()) + } } } diff --git a/share/ipldv2/sample.go b/share/ipldv2/sample.go index aecb7b8598..25242f4042 100644 --- a/share/ipldv2/sample.go +++ b/share/ipldv2/sample.go @@ -59,15 +59,26 @@ func NewSample(root *share.Root, idx int, axis rsmt2d.Axis, shr share.Share, pro // NewSampleFrom samples the EDS and constructs a new Sample. func NewSampleFrom(eds *rsmt2d.ExtendedDataSquare, idx int, axis rsmt2d.Axis) (*Sample, error) { sqrLn := int(eds.Width()) - rowIdx, shrIdx := idx/sqrLn, idx%sqrLn - shrs := eds.Row(uint(rowIdx)) + axisIdx, shrIdx := idx/sqrLn, idx%sqrLn + + // TODO(@Wondertan): Should be an rsmt2d method + var shrs [][]byte + switch axis { + case rsmt2d.Row: + shrs = eds.Row(uint(axisIdx)) + case rsmt2d.Col: + axisIdx, shrIdx = shrIdx, axisIdx + shrs = eds.Col(uint(axisIdx)) + default: + panic("invalid axis") + } root, err := share.NewRoot(eds) if err != nil { return nil, err } - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(rowIdx)) + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(axisIdx)) for _, shr := range shrs { err := tree.Push(shr) if err != nil {