From 783c4c190caf0f39d282e81c3dcecf5dfe74594b Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Thu, 21 Nov 2024 10:43:00 +0100 Subject: [PATCH] add context cancel test --- share/availability/light/availability.go | 13 ++-- share/availability/light/availability_test.go | 75 +++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/share/availability/light/availability.go b/share/availability/light/availability.go index f2bc5994e2..b0f33435aa 100644 --- a/share/availability/light/availability.go +++ b/share/availability/light/availability.go @@ -116,6 +116,11 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header log.Debugw("starting sampling session", "root", dah.String()) + idxs := make([]shwap.SampleCoords, len(samples.Remaining)) + for i, s := range samples.Remaining { + idxs[i] = shwap.SampleCoords{Row: s.Row, Col: s.Col} + } + // remove one second from the deadline to ensure we have enough time to process the results samplingCtx, cancel := context.WithCancel(ctx) if deadline, ok := ctx.Deadline(); ok { @@ -123,11 +128,6 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header } defer cancel() - idxs := make([]shwap.SampleCoords, len(samples.Remaining)) - for i, s := range samples.Remaining { - idxs[i] = shwap.SampleCoords{Row: s.Row, Col: s.Col} - } - smpls, errGetSamples := la.getter.GetSamples(samplingCtx, header, idxs) if len(smpls) == 0 { return share.ErrNotAvailable @@ -157,7 +157,8 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header return fmt.Errorf("store sampling result: %w", err) } - if errors.Is(errGetSamples, context.Canceled) { + if errors.Is(errGetSamples, context.Canceled) || + errors.Is(errGetSamples, context.DeadlineExceeded) { // Availability did not complete due to context cancellation, return context error instead of // share.ErrNotAvailable return context.Canceled diff --git a/share/availability/light/availability_test.go b/share/availability/light/availability_test.go index dce5021abb..16c2b0efda 100644 --- a/share/availability/light/availability_test.go +++ b/share/availability/light/availability_test.go @@ -384,6 +384,49 @@ func TestPrunePartialFailed(t *testing.T) { require.False(t, exist) } +func TestPruneWithCancelledContext(t *testing.T) { + const size = 8 + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + t.Cleanup(cancel) + + eds, h := randEdsAndHeader(t, size) + ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) + clientBs := blockstore.NewBlockstore(ds) + + ex := newTimeoutExchange(newExchangeOverEDS(ctx, t, eds)) + getter := bitswap.NewGetter(ex, clientBs, 0) + getter.Start() + defer getter.Stop() + + // Create a new ShareAvailability instance and sample the shares + sampleAmount := uint(20) + avail := NewShareAvailability(getter, ds, clientBs, WithSampleAmount(sampleAmount)) + + ctx2, cancel2 := context.WithTimeout(ctx, 1500*time.Millisecond) + defer cancel2() + + err := avail.SharesAvailable(ctx2, h) + require.Error(t, err, context.Canceled) + // close ShareAvailability to force flush of batched writes + avail.Close(ctx) + + preDeleteCount := countKeys(ctx, t, clientBs) + require.EqualValues(t, sampleAmount, preDeleteCount) + + // prune the samples + err = avail.Prune(ctx, h) + require.NoError(t, err) + + // Check if samples are deleted + postDeleteCount := countKeys(ctx, t, clientBs) + require.Zero(t, postDeleteCount) + + // Check if sampling result is deleted + exist, err := avail.ds.Has(ctx, datastoreKeyForRoot(h.DAH)) + require.NoError(t, err) + require.False(t, exist) +} + type halfSessionExchange struct { exchange.SessionExchange attempt atomic.Int32 @@ -417,6 +460,38 @@ func (hse *halfSessionExchange) GetBlocks(ctx context.Context, cids []cid.Cid) ( return out, nil } +type timeoutExchange struct { + exchange.SessionExchange +} + +func newTimeoutExchange(ex exchange.SessionExchange) *timeoutExchange { + return &timeoutExchange{SessionExchange: ex} +} + +func (hse *timeoutExchange) NewSession(context.Context) exchange.Fetcher { + return hse +} + +func (hse *timeoutExchange) GetBlocks(ctx context.Context, cids []cid.Cid) (<-chan blocks.Block, error) { + out := make(chan blocks.Block, len(cids)) + defer close(out) + + for _, cid := range cids { + + blk, err := hse.SessionExchange.GetBlock(ctx, cid) + if err != nil { + return nil, err + } + + out <- blk + } + + // sleep 1 second guarantees that we will exhaust context. + time.Sleep(time.Second) + + return out, nil +} + func randEdsAndHeader(t *testing.T, size int) (*rsmt2d.ExtendedDataSquare, *header.ExtendedHeader) { height := uint64(42) eds := edstest.RandEDS(t, size)