diff --git a/share/new_eds/accessor.go b/share/new_eds/accessor.go index 09920161f6..8fe740b29e 100644 --- a/share/new_eds/accessor.go +++ b/share/new_eds/accessor.go @@ -32,3 +32,12 @@ type AccessorCloser interface { Accessor io.Closer } + +type accessorCloser struct { + Accessor + io.Closer +} + +func WithCloser(a Accessor, c io.Closer) AccessorCloser { + return &accessorCloser{a, c} +} diff --git a/share/new_eds/validation.go b/share/new_eds/validation.go new file mode 100644 index 0000000000..826cfa55a5 --- /dev/null +++ b/share/new_eds/validation.go @@ -0,0 +1,76 @@ +package eds + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ Accessor = validation{} + +// ErrOutOfBounds is returned whenever an index is out of bounds. +var ErrOutOfBounds = errors.New("index is out of bounds") + +// validation is a Accessor implementation that performs sanity checks on methods. It wraps +// another Accessor and performs bounds checks on index arguments. +type validation struct { + Accessor + size *atomic.Int32 +} + +func WithValidation(f Accessor) Accessor { + return &validation{Accessor: f, size: new(atomic.Int32)} +} + +func (f validation) Size(ctx context.Context) int { + size := f.size.Load() + if size == 0 { + loaded := f.Accessor.Size(ctx) + f.size.Store(int32(loaded)) + return loaded + } + return int(size) +} + +func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { + if err := validateIndexBounds(ctx, f, colIdx); err != nil { + return shwap.Sample{}, fmt.Errorf("col: %w", err) + } + if err := validateIndexBounds(ctx, f, rowIdx); err != nil { + return shwap.Sample{}, fmt.Errorf("row: %w", err) + } + return f.Accessor.Sample(ctx, rowIdx, colIdx) +} + +func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { + if err := validateIndexBounds(ctx, f, axisIdx); err != nil { + return AxisHalf{}, fmt.Errorf("%s: %w", axisType, err) + } + return f.Accessor.AxisHalf(ctx, axisType, axisIdx) +} + +func (f validation) RowNamespaceData( + ctx context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + if err := validateIndexBounds(ctx, f, rowIdx); err != nil { + return shwap.RowNamespaceData{}, fmt.Errorf("row: %w", err) + } + return f.Accessor.RowNamespaceData(ctx, namespace, rowIdx) +} + +// validateIndexBounds checks if the index is within the bounds of the eds. +func validateIndexBounds(ctx context.Context, f Accessor, idx int) error { + size := f.Size(ctx) + if idx < 0 || idx >= size { + return fmt.Errorf("%w: index %d is out of bounds: [0, %d)", ErrOutOfBounds, idx, size) + } + return nil +} diff --git a/share/new_eds/validation_test.go b/share/new_eds/validation_test.go new file mode 100644 index 0000000000..ad4e6f2efc --- /dev/null +++ b/share/new_eds/validation_test.go @@ -0,0 +1,103 @@ +package eds + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/sharetest" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +func TestValidation_Sample(t *testing.T) { + tests := []struct { + name string + rowIdx, colIdx int + odsSize int + expectFail bool + }{ + {"ValidIndices", 3, 2, 4, false}, + {"OutOfBoundsX", 8, 3, 4, true}, + {"OutOfBoundsY", 3, 8, 4, true}, + {"NegativeX", -1, 4, 8, true}, + {"NegativeY", 3, -1, 8, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randEDS := edstest.RandEDS(t, tt.odsSize) + accessor := &Rsmt2D{ExtendedDataSquare: randEDS} + validation := WithValidation(WithCloser(accessor, nil)) + + _, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx) + if tt.expectFail { + require.ErrorIs(t, err, ErrOutOfBounds) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidation_AxisHalf(t *testing.T) { + tests := []struct { + name string + axisType rsmt2d.Axis + axisIdx int + odsSize int + expectFail bool + }{ + {"ValidIndex", rsmt2d.Row, 2, 4, false}, + {"OutOfBounds", rsmt2d.Col, 8, 4, true}, + {"NegativeIndex", rsmt2d.Row, -1, 4, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randEDS := edstest.RandEDS(t, tt.odsSize) + accessor := &Rsmt2D{ExtendedDataSquare: randEDS} + validation := WithValidation(WithCloser(accessor, nil)) + + _, err := validation.AxisHalf(context.Background(), tt.axisType, tt.axisIdx) + if tt.expectFail { + require.ErrorIs(t, err, ErrOutOfBounds) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidation_RowNamespaceData(t *testing.T) { + tests := []struct { + name string + rowIdx int + odsSize int + expectFail bool + }{ + {"ValidIndex", 3, 4, false}, + {"OutOfBounds", 8, 4, true}, + {"NegativeIndex", -1, 4, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randEDS := edstest.RandEDS(t, tt.odsSize) + accessor := &Rsmt2D{ExtendedDataSquare: randEDS} + validation := WithValidation(WithCloser(accessor, nil)) + + ns := sharetest.RandV0Namespace() + _, err := validation.RowNamespaceData(context.Background(), ns, tt.rowIdx) + if tt.expectFail { + require.ErrorIs(t, err, ErrOutOfBounds) + } else { + require.True(t, err == nil || errors.Is(err, shwap.ErrNamespaceOutsideRange), err) + } + }) + } +}