Skip to content

Commit

Permalink
add support for col proofs sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Wondertan committed Sep 21, 2023
1 parent f1de60a commit e804d2c
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 59 deletions.
58 changes: 40 additions & 18 deletions share/eds/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,42 @@ func (f *File) Header() Header {
return f.hdr
}

func (f *File) Axis(idx int, _ rsmt2d.Axis) ([]share.Share, error) {
// TODO: Add Col support
shrLn := int64(f.hdr.ShareSize)
sqrLn := int64(f.hdr.SquareSize)
rwwLn := shrLn * sqrLn
func (f *File) Axis(idx int, axis rsmt2d.Axis) ([]share.Share, error) {
shrLn := int(f.hdr.ShareSize)
sqrLn := int(f.hdr.SquareSize)

offset := int64(idx)*rwwLn + HeaderSize
rowdata := make([]byte, rwwLn)
if _, err := f.fl.ReadAt(rowdata, offset); err != nil {
return nil, err
}
shrs := make([]share.Share, sqrLn)
switch axis {
case rsmt2d.Col:
// [] [] [] []
// [] [] [] []
// [] [] [] []
// [] [] [] []

for i := 0; i < sqrLn; i++ {
pos := idx+i*sqrLn
offset := pos*shrLn+HeaderSize

shr := make(share.Share, shrLn)
if _, err := f.fl.ReadAt(shr, int64(offset)); err != nil {
return nil, err
}
shrs[i] = shr
}
case rsmt2d.Row:
pos := idx*sqrLn
offset := pos*shrLn + HeaderSize
axsData := make([]byte, sqrLn*shrLn)
if _, err := f.fl.ReadAt(axsData, int64(offset)); err != nil {
return nil, err
}

row := make([]share.Share, sqrLn)
for i := range row {
row[i] = rowdata[int64(i)*shrLn : (int64(i)+1)*shrLn]
for i := range shrs {
shrs[i] = axsData[i*shrLn : (i+1)*shrLn]
}
}
return row, nil

return shrs, nil
}

func (f *File) Share(idx int) (share.Share, error) {
Expand All @@ -120,21 +139,24 @@ func (f *File) Share(idx int) (share.Share, error) {
func (f *File) ShareWithProof(idx int, axis rsmt2d.Axis) (share.Share, nmt.Proof, error) {
// TODO: Cache the axis as well as computed tree
sqrLn := int(f.hdr.SquareSize)
rowIdx := idx / sqrLn
shrs, err := f.Axis(rowIdx, axis)
axsIdx, shrIdx := idx/sqrLn, idx%sqrLn
if axis == rsmt2d.Col {
axsIdx, shrIdx = shrIdx, axsIdx
}

shrs, err := f.Axis(axsIdx, axis)
if err != nil {
return nil, nmt.Proof{}, err
}

tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(rowIdx))
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(axsIdx))
for _, shr := range shrs {
err = tree.Push(shr)
if err != nil {
return nil, nmt.Proof{}, err
}
}

shrIdx := idx % sqrLn
proof, err := tree.ProveRange(shrIdx, shrIdx+1)
if err != nil {
return nil, nmt.Proof{}, err
Expand Down
72 changes: 47 additions & 25 deletions share/eds/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@ import (
"crypto/sha256"
"testing"

"github.com/celestiaorg/celestia-node/share"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/eds/edstest"
)

func TestFile(t *testing.T) {
path := t.TempDir() + "/testfile"
eds := edstest.RandEDS(t, 16)
eds := edstest.RandEDS(t, 8)
root, err := share.NewRoot(eds)
require.NoError(t, err)

fl, err := CreateFile(path, eds)
require.NoError(t, err)
Expand All @@ -25,33 +27,40 @@ func TestFile(t *testing.T) {
fl, err = OpenFile(path)
require.NoError(t, err)

for i := 0; i < int(eds.Width()); i++ {
row, err := fl.Axis(i, rsmt2d.Row)
require.NoError(t, err)
assert.EqualValues(t, eds.Row(uint(i)), row)
axis := []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row}
for _, axis := range axis {
for i := 0; i < int(eds.Width()); i++ {
row, err := fl.Axis(i, axis)
require.NoError(t, err)
assert.EqualValues(t, getAxis(i, axis, eds), row)
}
}

width := int(eds.Width())
for i := 0; i < width*width; i++ {
row, col := uint(i/width), uint(i%width)
shr, err := fl.Share(i)
require.NoError(t, err)
assert.EqualValues(t, eds.GetCell(row, col), shr)

shr, proof, err := fl.ShareWithProof(i, rsmt2d.Row)
require.NoError(t, err)
assert.EqualValues(t, eds.GetCell(row, col), shr)

roots, err := eds.RowRoots()
require.NoError(t, err)

namespace := share.ParitySharesNamespace
if int(row) < width/2 && int(col) < width/2 {
namespace = share.GetNamespace(shr)
}
for _, axis := range axis {
for i := 0; i < width*width; i++ {
row, col := uint(i/width), uint(i%width)
shr, err := fl.Share(i)
require.NoError(t, err)
assert.EqualValues(t, eds.GetCell(row, col), shr)

shr, proof, err := fl.ShareWithProof(i, axis)
require.NoError(t, err)
assert.EqualValues(t, eds.GetCell(row, col), shr)

namespace := share.ParitySharesNamespace
if int(row) < width/2 && int(col) < width/2 {
namespace = share.GetNamespace(shr)
}

ok := proof.VerifyInclusion(sha256.New(), namespace.ToNMT(), [][]byte{shr}, roots[row])
assert.True(t, ok)
dahroot := root.RowRoots[row]
if axis == rsmt2d.Col {
dahroot = root.ColumnRoots[col]
}

ok := proof.VerifyInclusion(sha256.New(), namespace.ToNMT(), [][]byte{shr}, dahroot)
assert.True(t, ok)
}
}

out, err := fl.EDS()
Expand All @@ -61,3 +70,16 @@ func TestFile(t *testing.T) {
err = fl.Close()
require.NoError(t, err)
}


// TODO(@Wondertan): Should be a method on eds
func getAxis(idx int, axis rsmt2d.Axis, eds *rsmt2d.ExtendedDataSquare) [][]byte {
switch axis {
case rsmt2d.Row:
return eds.Row(uint(idx))
case rsmt2d.Col:
return eds.Col(uint(idx))
default:
panic("")
}
}
29 changes: 16 additions & 13 deletions share/ipldv2/ipldv2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,27 @@ func TestV2Roundtrip(t *testing.T) {
dn.ConnectAll()

square := edstest.RandEDS(t, 16)
axis := []rsmt2d.Axis{rsmt2d.Col, rsmt2d.Row}
width := int(square.Width())
for i := 0; i < width*width; i++ {
smpl, err := NewSampleFrom(square, i, rsmt2d.Row)
require.NoError(t, err)
for _, axis := range axis {
for i := 0; i < width*width; i++ {
smpl, err := NewSampleFrom(square, i, axis)
require.NoError(t, err)

err = smpl.Validate()
require.NoError(t, err)
err = smpl.Validate()
require.NoError(t, err)

blkIn, err := smpl.IPLDBlock()
require.NoError(t, err)
blkIn, err := smpl.IPLDBlock()
require.NoError(t, err)

err = srv1.AddBlock(ctx, blkIn)
require.NoError(t, err)
err = srv1.AddBlock(ctx, blkIn)
require.NoError(t, err)

blkOut, err := srv2.GetBlock(ctx, blkIn.Cid())
require.NoError(t, err)
blkOut, err := srv2.GetBlock(ctx, blkIn.Cid())
require.NoError(t, err)

assert.EqualValues(t, blkIn.RawData(), blkOut.RawData())
assert.EqualValues(t, blkIn.Cid(), blkOut.Cid())
assert.EqualValues(t, blkIn.RawData(), blkOut.RawData())
assert.EqualValues(t, blkIn.Cid(), blkOut.Cid())
}
}
}
17 changes: 14 additions & 3 deletions share/ipldv2/sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,26 @@ func NewSample(root *share.Root, idx int, axis rsmt2d.Axis, shr share.Share, pro
// NewSampleFrom samples the EDS and constructs a new Sample.
func NewSampleFrom(eds *rsmt2d.ExtendedDataSquare, idx int, axis rsmt2d.Axis) (*Sample, error) {
sqrLn := int(eds.Width())
rowIdx, shrIdx := idx/sqrLn, idx%sqrLn
shrs := eds.Row(uint(rowIdx))
axisIdx, shrIdx := idx/sqrLn, idx%sqrLn

// TODO(@Wondertan): Should be an rsmt2d method
var shrs [][]byte
switch axis {
case rsmt2d.Row:
shrs = eds.Row(uint(axisIdx))
case rsmt2d.Col:
axisIdx, shrIdx = shrIdx, axisIdx
shrs = eds.Col(uint(axisIdx))
default:
panic("invalid axis")
}

root, err := share.NewRoot(eds)
if err != nil {
return nil, err
}

tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(rowIdx))
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(sqrLn/2), uint(axisIdx))
for _, shr := range shrs {
err := tree.Push(shr)
if err != nil {
Expand Down

0 comments on commit e804d2c

Please sign in to comment.