Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(shwap): Add validation file #3485

Merged
merged 7 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions share/new_eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ type AccessorCloser interface {
Accessor
io.Closer
}

type accessorCloser struct {
Accessor
io.Closer
}

func WithCloser(a Accessor, c io.Closer) AccessorCloser {
return &accessorCloser{a, c}
}
76 changes: 76 additions & 0 deletions share/new_eds/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package eds

import (
"context"
"errors"
"fmt"
"sync/atomic"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share"
"github.com/celestiaorg/celestia-node/share/shwap"
)

var _ Accessor = validation{}

// ErrOutOfBounds is returned whenever an index is out of bounds.
var ErrOutOfBounds = errors.New("index is out of bounds")

// validation is a Accessor implementation that performs sanity checks on methods. It wraps
// another Accessor and performs bounds checks on index arguments.
type validation struct {
Accessor
size *atomic.Int32
}

func WithValidation(f Accessor) Accessor {
return &validation{Accessor: f, size: new(atomic.Int32)}
}

func (f validation) Size(ctx context.Context) int {
size := f.size.Load()
if size == 0 {
loaded := f.Accessor.Size(ctx)
f.size.Store(int32(loaded))
return loaded
}
return int(size)
}

func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) {
if err := validateIndexBounds(ctx, f, colIdx); err != nil {
return shwap.Sample{}, fmt.Errorf("col: %w", err)
}
if err := validateIndexBounds(ctx, f, rowIdx); err != nil {
return shwap.Sample{}, fmt.Errorf("row: %w", err)
}
return f.Accessor.Sample(ctx, rowIdx, colIdx)
}

func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
if err := validateIndexBounds(ctx, f, axisIdx); err != nil {
return AxisHalf{}, fmt.Errorf("%s: %w", axisType, err)
}
return f.Accessor.AxisHalf(ctx, axisType, axisIdx)
}

func (f validation) RowNamespaceData(
ctx context.Context,
namespace share.Namespace,
rowIdx int,
) (shwap.RowNamespaceData, error) {
if err := validateIndexBounds(ctx, f, rowIdx); err != nil {
return shwap.RowNamespaceData{}, fmt.Errorf("row: %w", err)
}
return f.Accessor.RowNamespaceData(ctx, namespace, rowIdx)
}

// validateIndexBounds checks if the index is within the bounds of the eds.
func validateIndexBounds(ctx context.Context, f Accessor, idx int) error {
size := f.Size(ctx)
if idx < 0 || idx >= size {
return fmt.Errorf("%w: index %d is out of bounds: [0, %d)", ErrOutOfBounds, idx, size)
}
return nil
}
103 changes: 103 additions & 0 deletions share/new_eds/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package eds

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/require"

"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share/eds/edstest"
"github.com/celestiaorg/celestia-node/share/sharetest"
"github.com/celestiaorg/celestia-node/share/shwap"
)

func TestValidation_Sample(t *testing.T) {
tests := []struct {
name string
rowIdx, colIdx int
odsSize int
expectFail bool
}{
{"ValidIndices", 3, 2, 4, false},
{"OutOfBoundsX", 8, 3, 4, true},
{"OutOfBoundsY", 3, 8, 4, true},
{"NegativeX", -1, 4, 8, true},
{"NegativeY", 3, -1, 8, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randEDS := edstest.RandEDS(t, tt.odsSize)
accessor := &Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(WithCloser(accessor, nil))

_, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx)
if tt.expectFail {
require.ErrorIs(t, err, ErrOutOfBounds)
} else {
require.NoError(t, err)
}
})
}
}

func TestValidation_AxisHalf(t *testing.T) {
tests := []struct {
name string
axisType rsmt2d.Axis
axisIdx int
odsSize int
expectFail bool
}{
{"ValidIndex", rsmt2d.Row, 2, 4, false},
{"OutOfBounds", rsmt2d.Col, 8, 4, true},
{"NegativeIndex", rsmt2d.Row, -1, 4, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randEDS := edstest.RandEDS(t, tt.odsSize)
accessor := &Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(WithCloser(accessor, nil))

_, err := validation.AxisHalf(context.Background(), tt.axisType, tt.axisIdx)
if tt.expectFail {
require.ErrorIs(t, err, ErrOutOfBounds)
} else {
require.NoError(t, err)
}
})
}
}

func TestValidation_RowNamespaceData(t *testing.T) {
tests := []struct {
name string
rowIdx int
odsSize int
expectFail bool
}{
{"ValidIndex", 3, 4, false},
{"OutOfBounds", 8, 4, true},
{"NegativeIndex", -1, 4, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randEDS := edstest.RandEDS(t, tt.odsSize)
accessor := &Rsmt2D{ExtendedDataSquare: randEDS}
validation := WithValidation(WithCloser(accessor, nil))

ns := sharetest.RandV0Namespace()
_, err := validation.RowNamespaceData(context.Background(), ns, tt.rowIdx)
if tt.expectFail {
require.ErrorIs(t, err, ErrOutOfBounds)
} else {
require.True(t, err == nil || errors.Is(err, shwap.ErrNamespaceOutsideRange), err)
}
})
}
}
Loading