Skip to content

Commit

Permalink
feat(restore): improve context cancelling during batching
Browse files Browse the repository at this point in the history
I'm not sure whether the previous behavior was buggy, but the changes
introduced in this commit should make it clearer that the
batching mechanism respects context cancellation.
This commit also adds a simple test validating that
pausing a restore during batching ends quickly.
  • Loading branch information
Michal-Leszczynski committed Oct 21, 2024
1 parent 0c7447a commit 43491e9
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 9 deletions.
12 changes: 10 additions & 2 deletions pkg/service/restore/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package restore

import (
"context"
"slices"
"sync"

Expand Down Expand Up @@ -198,8 +199,12 @@ func (bd *batchDispatcher) ValidateAllDispatched() error {
// failed to be restored (see batchDispatcher.wait description for more information).
// Because of that, it's important to call ReportSuccess or ReportFailure after
// each dispatched batch was attempted to be restored.
func (bd *batchDispatcher) DispatchBatch(host string) (batch, bool) {
func (bd *batchDispatcher) DispatchBatch(ctx context.Context, host string) (batch, bool) {
for {
if ctx.Err() != nil {
return batch{}, false
}

bd.mu.Lock()
// Check if there is anything to batch
if bd.done {
Expand All @@ -224,7 +229,10 @@ func (bd *batchDispatcher) DispatchBatch(host string) (batch, bool) {
bd.waitCnt++
bd.mu.Unlock()

<-wait
select {
case <-ctx.Done():
case <-wait:
}

bd.mu.Lock()
bd.waitCnt--
Expand Down
49 changes: 49 additions & 0 deletions pkg/service/restore/restore_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,4 +833,53 @@ func TestRestoreTablesBatchRetryIntegration(t *testing.T) {
t.Fatal("Restore hanged")
}
})

t.Run("paused restore with slow calls to download and las", func(t *testing.T) {
	// Scenario: cancel the restore context as soon as the first download or
	// load-and-stream request is observed, then verify that Restore returns
	// context.Canceled within a short deadline.
	Print("Make download and las calls slow")
	reachedDataStage := atomic.Bool{}
	reachedDataStageChan := make(chan struct{})
	h.dstCluster.Hrt.SetInterceptor(httpx.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
		if strings.HasPrefix(req.URL.Path, "/agent/rclone/sync/copypaths") ||
			strings.HasPrefix(req.URL.Path, "/storage_service/sstables/") {
			// CompareAndSwap guarantees the channel is closed exactly once,
			// even when several matching requests arrive concurrently.
			if reachedDataStage.CompareAndSwap(false, true) {
				close(reachedDataStageChan)
			}
			// Delay every matching call by one second.
			time.Sleep(time.Second)
		}
		// NOTE(review): presumably a nil response with a nil error lets the
		// request continue to the real transport - confirm against the
		// interceptor contract of Hrt.
		return nil, nil
	}))

	Print("Run restore")
	grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser)
	h.dstCluster.TaskID = uuid.NewTime()
	h.dstCluster.RunID = uuid.NewTime()
	rawProps, err := json.Marshal(props)
	if err != nil {
		t.Fatal(errors.Wrap(err, "marshal properties"))
	}
	ctx, cancel := context.WithCancel(context.Background())
	// Release the context on every exit path, including early test failures.
	defer cancel()
	// Buffered so that the goroutine below can always deliver its result and
	// exit, even if the test has already given up waiting (e.g. on timeout).
	res := make(chan error, 1)
	go func() {
		res <- h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps)
	}()

	Print("Wait for data stage")
	select {
	case <-reachedDataStageChan:
		cancel()
	case err := <-res:
		// %v is used because err may be nil here.
		t.Fatalf("Restore finished before reaching data stage with: %v", err)
	}

	Print("Validate restore was paused in time")
	select {
	case err := <-res:
		if !errors.Is(err, context.Canceled) {
			// %v is used because err may be nil here.
			t.Fatalf("Expected restore to end with context cancelled, got %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("Restore wasn't paused in time")
	}
})
}
14 changes: 7 additions & 7 deletions pkg/service/restore/tables_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,11 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error {
hi := w.hostInfo(host, dc, hostToShard[host])
w.logger.Info(ctx, "Host info", "host", hi.Host, "transfers", hi.Transfers, "rate limit", hi.RateLimit)
for {
if ctx.Err() != nil {
return ctx.Err()
}
// Download and stream in parallel
b, ok := bd.DispatchBatch(hi.Host)
b, ok := bd.DispatchBatch(ctx, hi.Host)
if !ok {
w.logger.Info(ctx, "No more batches to restore", "host", hi.Host)
return nil
Expand All @@ -232,19 +235,13 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error {
w.logger.Error(ctx, "Failed to create new run progress",
"host", hi.Host,
"error", err)
if ctx.Err() != nil {
return err
}
continue
}
if err := w.restoreBatch(ctx, b, pr); err != nil {
err = multierr.Append(errors.Wrap(err, "restore batch"), bd.ReportFailure(hi.Host, b))
w.logger.Error(ctx, "Failed to restore batch",
"host", hi.Host,
"error", err)
if ctx.Err() != nil {
return err
}
continue
}
bd.ReportSuccess(b)
Expand All @@ -261,6 +258,9 @@ func (w *tablesWorker) stageRestoreData(ctx context.Context) error {

err = parallel.Run(len(hosts), w.target.Parallel, f, notify)
if err == nil {
if ctx.Err() != nil {
return ctx.Err()
}
return bd.ValidateAllDispatched()
}
return err
Expand Down

0 comments on commit 43491e9

Please sign in to comment.