Skip to content

Commit

Permalink
add validation file
Browse files Browse the repository at this point in the history
  • Loading branch information
walldiss committed Jun 11, 2024
1 parent c78f081 commit 9b7c3ba
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
65 changes: 65 additions & 0 deletions store/file/validating.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package file

import (
"context"
"errors"
"fmt"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
eds "github.com/celestiaorg/celestia-node/share/new_eds"
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ eds.AccessorCloser = validation{}

// ErrOutOfBounds is returned whenever an index is out of bounds.
var ErrOutOfBounds = errors.New("index is out of bounds")

// validation is a AccessorCloser implementation that performs sanity checks on methods. It wraps
// another AccessorCloser and performs bounds checks on index arguments.
type validation struct {
eds.AccessorCloser
}

func WithValidation(f eds.AccessorCloser) eds.AccessorCloser {
return &validation{AccessorCloser: f}
}

func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
if err := validateIndexBounds(ctx, f, rowIdx); err != nil {
return shwap.Sample{}, fmt.Errorf("col: %w", err)
}
if err := validateIndexBounds(ctx, f, colIdx); err != nil {
return shwap.Sample{}, fmt.Errorf("row: %w", err)
}
return f.AccessorCloser.Sample(ctx, rowIdx, colIdx)
}

func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
if err := validateIndexBounds(ctx, f, axisIdx); err != nil {
return eds.AxisHalf{}, fmt.Errorf("%s: %w", axisType, err)
}
return f.AccessorCloser.AxisHalf(ctx, axisType, axisIdx)
}

func (f validation) RowNamespaceData(
ctx context.Context,
namespace share.Namespace,
rowIdx int,
) (shwap.RowNamespaceData, error) {
if err := validateIndexBounds(ctx, f, rowIdx); err != nil {
return shwap.RowNamespaceData{}, fmt.Errorf("row: %w", err)
}
return f.AccessorCloser.RowNamespaceData(ctx, namespace, rowIdx)
}

// validateIndexBounds checks if the index is within the bounds of the eds.
func validateIndexBounds(ctx context.Context, f eds.AccessorCloser, idx int) error {
size := f.Size(ctx)
if idx < 0 || idx >= size {
return fmt.Errorf("%w: index %d is out of bounds: [0, %d)", ErrOutOfBounds, idx, size)
}
return nil
}
114 changes: 114 additions & 0 deletions store/file/validating_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package file

import (
"context"
"errors"
"io"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/ipld"
eds "github.com/celestiaorg/celestia-node/share/new_eds"
"github.com/celestiaorg/celestia-node/share/sharetest"
)

func TestValidation_Sample(t *testing.T) {
tests := []struct {
name string
rowIdx, colIdx int
odsSize int
expectFail bool
}{
{"ValidIndices", 3, 2, 4, false},
{"OutOfBoundsX", 8, 3, 4, true},
{"OutOfBoundsY", 3, 8, 4, true},
{"NegativeX", -1, 4, 8, true},
{"NegativeY", 3, -1, 8, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randEDS := edstest.RandEDS(t, tt.odsSize)
accessor := &eds.Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(withCloser(accessor, nil))

_, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx)
if tt.expectFail {
require.ErrorIs(t, err, ErrOutOfBounds)
} else {
require.NoError(t, err)
}
})
}
}

func TestValidation_AxisHalf(t *testing.T) {
tests := []struct {
name string
axisType rsmt2d.Axis
axisIdx int
odsSize int
expectFail bool
}{
{"ValidIndex", rsmt2d.Row, 2, 4, false},
{"OutOfBounds", rsmt2d.Col, 8, 4, true},
{"NegativeIndex", rsmt2d.Row, -1, 4, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randEDS := edstest.RandEDS(t, tt.odsSize)
accessor := &eds.Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(withCloser(accessor, nil))

_, err := validation.AxisHalf(context.Background(), tt.axisType, tt.axisIdx)
if tt.expectFail {
require.ErrorIs(t, err, ErrOutOfBounds)
} else {
require.NoError(t, err)
}
})
}
}

func TestValidation_RowNamespaceData(t *testing.T) {
tests := []struct {
name string
rowIdx int
odsSize int
expectFail bool
}{
{"ValidIndex", 3, 4, false},
{"OutOfBounds", 8, 4, true},
{"NegativeIndex", -1, 4, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randEDS := edstest.RandEDS(t, tt.odsSize)
accessor := &eds.Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(withCloser(accessor, nil))

ns := sharetest.RandV0Namespace()
_, err := validation.RowNamespaceData(context.Background(), ns, tt.rowIdx)
if tt.expectFail {
require.ErrorIs(t, err, ErrOutOfBounds)
} else {
require.True(t, err == nil || errors.Is(err, ipld.ErrNamespaceOutsideRange))
}
})
}
}

func withCloser(accessor eds.Accessor, closer io.Closer) eds.AccessorCloser {
return &accessorCloser{Accessor: accessor, Closer: closer}
}

type accessorCloser struct {
eds.Accessor
io.Closer
}

0 comments on commit 9b7c3ba

Please sign in to comment.