From 43491e90ea4c178b7701fab490353a2cbf39ad71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Leszczy=C5=84ski?= <2000michal@wp.pl> Date: Mon, 21 Oct 2024 09:13:22 +0200 Subject: [PATCH] feat(restore): improve context cancelling during batching I'm not sure if previous behavior was bugged, but changes introduced in this commit should make it more clear that batching mechanism respects context cancellation. This commit also adds a simple test validating that pausing restore during batching ends quickly. --- pkg/service/restore/batch.go | 12 ++++- .../restore/restore_integration_test.go | 49 +++++++++++++++++++ pkg/service/restore/tables_worker.go | 14 +++--- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/pkg/service/restore/batch.go b/pkg/service/restore/batch.go index bbd22c88e..b0cc4f03a 100644 --- a/pkg/service/restore/batch.go +++ b/pkg/service/restore/batch.go @@ -3,6 +3,7 @@ package restore import ( + "context" "slices" "sync" @@ -198,8 +199,12 @@ func (bd *batchDispatcher) ValidateAllDispatched() error { // failed to be restored (see batchDispatcher.wait description for more information). // Because of that, it's important to call ReportSuccess or ReportFailure after // each dispatched batch was attempted to be restored. -func (bd *batchDispatcher) DispatchBatch(host string) (batch, bool) { +func (bd *batchDispatcher) DispatchBatch(ctx context.Context, host string) (batch, bool) { for { + if ctx.Err() != nil { + return batch{}, false + } + bd.mu.Lock() // Check if there is anything to batch if bd.done { @@ -224,7 +229,10 @@ func (bd *batchDispatcher) DispatchBatch(host string) (batch, bool) { bd.waitCnt++ bd.mu.Unlock() - <-wait + select { + case <-ctx.Done(): + case <-wait: + } bd.mu.Lock() bd.waitCnt-- diff --git a/pkg/service/restore/restore_integration_test.go b/pkg/service/restore/restore_integration_test.go index de97405ee..13a0fe9db 100644 --- a/pkg/service/restore/restore_integration_test.go +++ b/pkg/service/restore/restore_integration_test.go @@ -833,4 +833,53 @@ func TestRestoreTablesBatchRetryIntegration(t *testing.T) { t.Fatal("Restore hanged") } }) + + t.Run("paused restore with slow calls to download and las", func(t *testing.T) { + Print("Make download and las calls slow") + reachedDataStage := atomic.Bool{} + reachedDataStageChan := make(chan struct{}) + h.dstCluster.Hrt.SetInterceptor(httpx.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.Path, "/agent/rclone/sync/copypaths") || + strings.HasPrefix(req.URL.Path, "/storage_service/sstables/") { + if reachedDataStage.CompareAndSwap(false, true) { + close(reachedDataStageChan) + } + time.Sleep(time.Second) + return nil, nil + } + return nil, nil + })) + + Print("Run restore") + grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser) + h.dstCluster.TaskID = uuid.NewTime() + h.dstCluster.RunID = uuid.NewTime() + rawProps, err := json.Marshal(props) + if err != nil { + t.Fatal(errors.Wrap(err, "marshal properties")) + } + ctx, cancel := context.WithCancel(context.Background()) + res := make(chan error) + go func() { + res <- h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps) + }() + + Print("Wait for data stage") + select { + case <-reachedDataStageChan: + cancel() + case err := <-res: + t.Fatalf("Restore finished before reaching data stage with: %s", err) + } + + Print("Validate restore was paused in time") + select { + case err := <-res: + if !errors.Is(err, context.Canceled) { + t.Fatalf("Expected restore to end with context cancelled, got %q", err) + } + case <-time.NewTimer(2 * time.Second).C: + t.Fatal("Restore wasn't paused in time") + } + }) } diff --git a/pkg/service/restore/tables_worker.go b/pkg/service/restore/tables_worker.go index 098b84be5..045f1f8e1 100644 --- a/pkg/service/restore/tables_worker.go +++ b/pkg/service/restore/tables_worker.go @@ -211,8 +211,11 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { hi := w.hostInfo(host, dc, hostToShard[host]) w.logger.Info(ctx, "Host info", "host", hi.Host, "transfers", hi.Transfers, "rate limit", hi.RateLimit) for { + if ctx.Err() != nil { + return ctx.Err() + } // Download and stream in parallel - b, ok := bd.DispatchBatch(hi.Host) + b, ok := bd.DispatchBatch(ctx, hi.Host) if !ok { w.logger.Info(ctx, "No more batches to restore", "host", hi.Host) return nil @@ -232,9 +235,6 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { w.logger.Error(ctx, "Failed to create new run progress", "host", hi.Host, "error", err) - if ctx.Err() != nil { - return err - } continue } if err := w.restoreBatch(ctx, b, pr); err != nil { @@ -242,9 +242,6 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { w.logger.Error(ctx, "Failed to restore batch", "host", hi.Host, "error", err) - if ctx.Err() != nil { - return err - } continue } bd.ReportSuccess(b) @@ -261,6 +258,9 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error { err = parallel.Run(len(hosts), w.target.Parallel, f, notify) if err == nil { + if ctx.Err() != nil { + return ctx.Err() + } return bd.ValidateAllDispatched() } return err