From 9b7c3ba1a395dc3859181e31e40764493f5270d7 Mon Sep 17 00:00:00 2001 From: Vlad <13818348+walldiss@users.noreply.github.com> Date: Mon, 10 Jun 2024 13:16:27 +0500 Subject: [PATCH] add validation file --- store/file/validating.go | 65 +++++++++++++++++++ store/file/validating_test.go | 114 ++++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 store/file/validating.go create mode 100644 store/file/validating_test.go diff --git a/store/file/validating.go b/store/file/validating.go new file mode 100644 index 0000000000..1d551d6f25 --- /dev/null +++ b/store/file/validating.go @@ -0,0 +1,65 @@ +package file + +import ( + "context" + "errors" + "fmt" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/shwap" +) + +var _ eds.AccessorCloser = validation{} + +// ErrOutOfBounds is returned whenever an index is out of bounds. +var ErrOutOfBounds = errors.New("index is out of bounds") + +// validation is a AccessorCloser implementation that performs sanity checks on methods. It wraps +// another AccessorCloser and performs bounds checks on index arguments. +type validation struct { + eds.AccessorCloser +} + +func WithValidation(f eds.AccessorCloser) eds.AccessorCloser { + return &validation{AccessorCloser: f} +} + +func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { + if err := validateIndexBounds(ctx, f, rowIdx); err != nil { + return shwap.Sample{}, fmt.Errorf("col: %w", err) + } + if err := validateIndexBounds(ctx, f, colIdx); err != nil { + return shwap.Sample{}, fmt.Errorf("row: %w", err) + } + return f.AccessorCloser.Sample(ctx, rowIdx, colIdx) +} + +func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) { + if err := validateIndexBounds(ctx, f, axisIdx); err != nil { + return eds.AxisHalf{}, fmt.Errorf("%s: %w", axisType, err) + } + return f.AccessorCloser.AxisHalf(ctx, axisType, axisIdx) +} + +func (f validation) RowNamespaceData( + ctx context.Context, + namespace share.Namespace, + rowIdx int, +) (shwap.RowNamespaceData, error) { + if err := validateIndexBounds(ctx, f, rowIdx); err != nil { + return shwap.RowNamespaceData{}, fmt.Errorf("row: %w", err) + } + return f.AccessorCloser.RowNamespaceData(ctx, namespace, rowIdx) +} + +// validateIndexBounds checks if the index is within the bounds of the eds. +func validateIndexBounds(ctx context.Context, f eds.AccessorCloser, idx int) error { + size := f.Size(ctx) + if idx < 0 || idx >= size { + return fmt.Errorf("%w: index %d is out of bounds: [0, %d)", ErrOutOfBounds, idx, size) + } + return nil +} diff --git a/store/file/validating_test.go b/store/file/validating_test.go new file mode 100644 index 0000000000..36c9705cca --- /dev/null +++ b/store/file/validating_test.go @@ -0,0 +1,114 @@ +package file + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/ipld" + eds "github.com/celestiaorg/celestia-node/share/new_eds" + "github.com/celestiaorg/celestia-node/share/sharetest" +) + +func TestValidation_Sample(t *testing.T) { + tests := []struct { + name string + rowIdx, colIdx int + odsSize int + expectFail bool + }{ + {"ValidIndices", 3, 2, 4, false}, + {"OutOfBoundsX", 8, 3, 4, true}, + {"OutOfBoundsY", 3, 8, 4, true}, + {"NegativeX", -1, 4, 8, true}, + {"NegativeY", 3, -1, 8, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randEDS := edstest.RandEDS(t, tt.odsSize) + accessor := &eds.Rsmt2D{ExtendedDataSquare: randEDS} + validation := WithValidation(withCloser(accessor, nil)) + + _, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx) + if tt.expectFail { + require.ErrorIs(t, err, ErrOutOfBounds) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidation_AxisHalf(t *testing.T) { + tests := []struct { + name string + axisType rsmt2d.Axis + axisIdx int + odsSize int + expectFail bool + }{ + {"ValidIndex", rsmt2d.Row, 2, 4, false}, + {"OutOfBounds", rsmt2d.Col, 8, 4, true}, + {"NegativeIndex", rsmt2d.Row, -1, 4, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randEDS := edstest.RandEDS(t, tt.odsSize) + accessor := &eds.Rsmt2D{ExtendedDataSquare: randEDS} + validation := WithValidation(withCloser(accessor, nil)) + + _, err := validation.AxisHalf(context.Background(), tt.axisType, tt.axisIdx) + if tt.expectFail { + require.ErrorIs(t, err, ErrOutOfBounds) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidation_RowNamespaceData(t *testing.T) { + tests := []struct { + name string + rowIdx int + odsSize int + expectFail bool + }{ + {"ValidIndex", 3, 4, false}, + {"OutOfBounds", 8, 4, true}, + {"NegativeIndex", -1, 4, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randEDS := edstest.RandEDS(t, tt.odsSize) + accessor := &eds.Rsmt2D{ExtendedDataSquare: randEDS} + validation := WithValidation(withCloser(accessor, nil)) + + ns := sharetest.RandV0Namespace() + _, err := validation.RowNamespaceData(context.Background(), ns, tt.rowIdx) + if tt.expectFail { + require.ErrorIs(t, err, ErrOutOfBounds) + } else { + require.True(t, err == nil || errors.Is(err, ipld.ErrNamespaceOutsideRange)) + } + }) + } +} + +func withCloser(accessor eds.Accessor, closer io.Closer) eds.AccessorCloser { + return &accessorCloser{Accessor: accessor, Closer: closer} +} + +type accessorCloser struct { + eds.Accessor + io.Closer +}