feat(restore_test): test restore pause/resume
This commit adds the runPausedRestore function, which runs a restore that is
interrupted every pauseInterval and is expected to be paused at least
minPauseCnt times.
Pausing after an arbitrary amount of time might seem flaky,
but we can't always rely on scyllaclient hooks for pausing the restore,
as we are biased toward where we place them.

This commit also uses runPausedRestore in TestRestoreTablesPausedIntegration,
which tests #4037.
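
The pause/resume mechanics boil down to canceling the restore's context on a timer and re-running it until it returns something other than context.Canceled. Below is a minimal, self-contained sketch of that pattern; fakeRestore, the step counter, and the durations are invented for illustration and are not part of this change:

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

// fakeRestore stands in for a real restore: it does work in small steps and
// returns context.Canceled if it gets paused (canceled) mid-way.
func fakeRestore(ctx context.Context, done *int) error {
    for ; *done < 10; *done++ {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-time.After(100 * time.Millisecond):
        }
    }
    return nil
}

func main() {
    done := 0
    for pauseCnt := 0; ; {
        ctx, cancel := context.WithCancel(context.Background())
        // Pause the attempt after a fixed interval by canceling its context.
        time.AfterFunc(350*time.Millisecond, cancel)

        err := fakeRestore(ctx, &done)
        cancel()
        switch {
        case err == nil:
            fmt.Printf("restore finished after %d pauses\n", pauseCnt)
            return
        case !errors.Is(err, context.Canceled):
            fmt.Println("restore failed:", err)
            return
        default:
            // Paused: resume from where the previous attempt left off.
            pauseCnt++
        }
    }
}

The actual helper in the diff, runPausedRestore, follows the same idea but runs the restore in a goroutine and reads its result from a channel, so a ticker can interrupt it every pauseInterval and count the pauses.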
Michal-Leszczynski committed Sep 23, 2024
1 parent 564546d commit 6bd0a7a
Showing 1 changed file with 187 additions and 0 deletions: pkg/service/restore/restore_integration_test.go
@@ -7,17 +7,24 @@ package restore_test

import (
"context"
"encoding/json"
"fmt"
"slices"
"strings"
"testing"
"time"

"github.com/pkg/errors"
"github.com/scylladb/gocqlx/v2"
"github.com/scylladb/scylla-manager/v3/pkg/service/backup"
. "github.com/scylladb/scylla-manager/v3/pkg/service/backup/backupspec"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/db"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig"
"github.com/scylladb/scylla-manager/v3/pkg/util/maputil"
"github.com/scylladb/scylla-manager/v3/pkg/util/query"
"github.com/scylladb/scylla-manager/v3/pkg/util/uuid"
"go.uber.org/multierr"
)

func TestRestoreTablesUserIntegration(t *testing.T) {
@@ -334,3 +341,183 @@ func TestRestoreTablesVnodeToTabletsIntegration(t *testing.T) {

    validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab, c1, c2)
}

func TestRestoreTablesPausedIntegration(t *testing.T) {
    testCases := []struct {
        rf            int
        pauseInterval time.Duration
        minPauseCnt   int
    }{
        {rf: 1, pauseInterval: time.Hour, minPauseCnt: 0},
        {rf: 2, pauseInterval: 45 * time.Second, minPauseCnt: 2},
        {rf: 1, pauseInterval: 45 * time.Second, minPauseCnt: 2},
    }

    for _, tc := range testCases {
        t.Run(fmt.Sprintf("rf: %d, pauseInterval: %d, minPauseCnt: %d", tc.rf, tc.pauseInterval, tc.minPauseCnt), func(t *testing.T) {
            h := newTestHelper(t, ManagedSecondClusterHosts(), ManagedClusterHosts())
            kss := []string{randomizedName("paused_1_"), randomizedName("paused_2_")}
            tabs := []string{randomizedName("tab_1_"), randomizedName("tab_2_")}
            mv := randomizedName("mv")
            index := randomizedName("index")

            Print("Create and fill tables in src cluster")
            var units []backup.Unit
            for _, ks := range kss {
                Print(fmt.Sprintf("Create %q in both clusters", ks))
                ksStmt := "CREATE KEYSPACE %q WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': %d}"
                ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(ksStmt, ks, tc.rf))
                ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(ksStmt, ks, tc.rf))
                u := backup.Unit{Keyspace: ks, AllTables: true}

                for i, tab := range tabs {
                    Print(fmt.Sprintf("Create %q.%q in both clusters", ks, tab))
                    createTable(t, h.srcCluster.rootSession, ks, tab)
                    createTable(t, h.dstCluster.rootSession, ks, tab)
                    u.Tables = append(u.Tables, tab)

                    if i == 0 {
                        Print(fmt.Sprintf("Create MV and SI for %s.%s in both clusters", ks, tab))
                        CreateMaterializedView(t, h.srcCluster.rootSession, ks, tab, mv)
                        CreateMaterializedView(t, h.dstCluster.rootSession, ks, tab, mv)
                        CreateSecondaryIndex(t, h.srcCluster.rootSession, ks, tab, index)
                        CreateSecondaryIndex(t, h.dstCluster.rootSession, ks, tab, index)
                        u.Tables = append(u.Tables, mv, index+"_index")
                    }

                    Print(fmt.Sprintf("Fill %s.%s in src cluster", ks, tab))
                    fillTable(t, h.srcCluster.rootSession, 1, ks, tab)
                }
                units = append(units, u)
            }

            Print("Run backup")
            loc := []Location{testLocation("paused", "")}
            S3InitBucket(t, loc[0].Path)
            ksFilter := slices.Clone(kss)

            // Starting from SM 3.3.1, SM does not allow backing up views,
            // but backed-up views should still be tested, as older backups
            // might contain them. That's why we manually force the backup
            // target to contain the views.
            ctx := context.Background()
            h.srcCluster.TaskID = uuid.NewTime()
            h.srcCluster.RunID = uuid.NewTime()

            rawProps, err := json.Marshal(map[string]any{
                "location": loc,
                "keyspace": ksFilter,
            })
            if err != nil {
                t.Fatal(errors.Wrap(err, "marshal properties"))
            }

            target, err := h.srcBackupSvc.GetTarget(ctx, h.srcCluster.ClusterID, rawProps)
            if err != nil {
                t.Fatal(errors.Wrap(err, "generate target"))
            }
            target.Units = units

            err = h.srcBackupSvc.Backup(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID, target)
            if err != nil {
                t.Fatal(errors.Wrap(err, "run backup"))
            }

            pr, err := h.srcBackupSvc.GetProgress(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID)
            if err != nil {
                t.Fatal(errors.Wrap(err, "get progress"))
            }
            tag := pr.SnapshotTag

            Print("Run restore tables")
            grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser)
            props := map[string]any{
                "location": loc,
                "keyspace": ksFilter,
                "snapshot_tag": tag,
                "restore_tables": true,
            }
            err = runPausedRestore(t, func(ctx context.Context) error {
                h.dstCluster.RunID = uuid.NewTime()
                rawProps, err := json.Marshal(props)
                if err != nil {
                    return err
                }
                return h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps)
            }, tc.pauseInterval, tc.minPauseCnt)
            if err != nil {
                t.Fatal(err)
            }

            for _, ks := range kss {
                for i, tab := range tabs {
                    validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab, "id", "data")
                    if i == 0 {
                        validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, mv, "id", "data")
                        validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, index+"_index", "id", "data")
                    }
                }
            }
        })
    }
}

func createTable(t *testing.T, session gocqlx.Session, keyspace string, tables ...string) {
    for _, tab := range tables {
        ExecStmt(t, session, fmt.Sprintf("CREATE TABLE %q.%q (id int PRIMARY KEY, data int)", keyspace, tab))
    }
}

func fillTable(t *testing.T, session gocqlx.Session, rowCnt int, keyspace string, tables ...string) {
    for _, tab := range tables {
        stmt := fmt.Sprintf("INSERT INTO %q.%q (id, data) VALUES (?, ?)", keyspace, tab)
        q := session.Query(stmt, []string{"id", "data"})

        for i := 0; i < rowCnt; i++ {
            if err := q.Bind(i, i).Exec(); err != nil {
                t.Fatal(err)
            }
        }

        q.Release()
    }
}

func runPausedRestore(t *testing.T, restore func(ctx context.Context) error, pauseInterval time.Duration, minPauseCnt int) (err error) {
    t.Helper()

    ticker := time.NewTicker(pauseInterval)
    ctx, cancel := context.WithCancel(context.Background())
    res := make(chan error)
    pauseCnt := 0
    defer func() {
        t.Logf("Restore was paused %d times", pauseCnt)
        if pauseCnt < minPauseCnt {
            err = multierr.Append(err, errors.Errorf("expected to pause at least %d times, got %d", minPauseCnt, pauseCnt))
        }
    }()

    go func() {
        res <- restore(ctx)
    }()
    for {
        select {
        case err := <-res:
            // Restore finished (or failed) before the next pause.
            cancel()
            return err
        case <-ticker.C:
            // Pause restore by canceling its context and wait for it to return.
            t.Log("Pause restore")
            cancel()
            err := <-res
            if err == nil || !errors.Is(err, context.Canceled) {
                return err
            }

            // Resume restore with a fresh context.
            pauseCnt++
            ctx, cancel = context.WithCancel(context.Background())
            go func() {
                res <- restore(ctx)
            }()
        }
    }
}
