diff --git a/pkg/service/restore/restore_integration_test.go b/pkg/service/restore/restore_integration_test.go
index efbd588df..3d0af6448 100644
--- a/pkg/service/restore/restore_integration_test.go
+++ b/pkg/service/restore/restore_integration_test.go
@@ -7,17 +7,24 @@ package restore_test
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
+	"slices"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/pkg/errors"
+	"github.com/scylladb/gocqlx/v2"
+	"github.com/scylladb/scylla-manager/v3/pkg/service/backup"
 	. "github.com/scylladb/scylla-manager/v3/pkg/service/backup/backupspec"
 	. "github.com/scylladb/scylla-manager/v3/pkg/testutils"
 	. "github.com/scylladb/scylla-manager/v3/pkg/testutils/db"
 	. "github.com/scylladb/scylla-manager/v3/pkg/testutils/testconfig"
 	"github.com/scylladb/scylla-manager/v3/pkg/util/maputil"
 	"github.com/scylladb/scylla-manager/v3/pkg/util/query"
+	"github.com/scylladb/scylla-manager/v3/pkg/util/uuid"
+	"go.uber.org/multierr"
 )
 
 func TestRestoreTablesUserIntegration(t *testing.T) {
@@ -334,3 +341,183 @@ func TestRestoreTablesVnodeToTabletsIntegration(t *testing.T) {
 
 	validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab, c1, c2)
 }
+
+func TestRestoreTablesPausedIntegration(t *testing.T) {
+	testCases := []struct {
+		rf            int
+		pauseInterval time.Duration
+		minPauseCnt   int
+	}{
+		{rf: 1, pauseInterval: time.Hour, minPauseCnt: 0},
+		{rf: 2, pauseInterval: 45 * time.Second, minPauseCnt: 2},
+		{rf: 1, pauseInterval: 45 * time.Second, minPauseCnt: 2},
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("rf: %d, pauseInterval: %d, minPauseCnt: %d", tc.rf, tc.pauseInterval, tc.minPauseCnt), func(t *testing.T) {
+			h := newTestHelper(t, ManagedSecondClusterHosts(), ManagedClusterHosts())
+			kss := []string{randomizedName("paused_1_"), randomizedName("paused_2_")}
+			tabs := []string{randomizedName("tab_1_"), randomizedName("tab_2_")}
+			mv := randomizedName("mv")
+			index := randomizedName("index")
+
+			Print("Create and fill tables in src cluster")
+			var units []backup.Unit
+			for _, ks := range kss {
+				Print(fmt.Sprintf("Create %q in both clusters", ks))
+				ksStmt := "CREATE KEYSPACE %q WITH replication = {'class': 'NetworkTopologyStrategy', 'dc1': %d}"
+				ExecStmt(t, h.srcCluster.rootSession, fmt.Sprintf(ksStmt, ks, tc.rf))
+				ExecStmt(t, h.dstCluster.rootSession, fmt.Sprintf(ksStmt, ks, tc.rf))
+				u := backup.Unit{Keyspace: ks, AllTables: true}
+
+				for i, tab := range tabs {
+					Print(fmt.Sprintf("Create %q.%q in both clusters", ks, tab))
+					createTable(t, h.srcCluster.rootSession, ks, tab)
+					createTable(t, h.dstCluster.rootSession, ks, tab)
+					u.Tables = append(u.Tables, tab)
+
+					if i == 0 {
+						Print(fmt.Sprintf("Create MV and SI for %s.%s in both clusters", ks, tab))
+						CreateMaterializedView(t, h.srcCluster.rootSession, ks, tab, mv)
+						CreateMaterializedView(t, h.dstCluster.rootSession, ks, tab, mv)
+						CreateSecondaryIndex(t, h.srcCluster.rootSession, ks, tab, index)
+						CreateSecondaryIndex(t, h.dstCluster.rootSession, ks, tab, index)
+						u.Tables = append(u.Tables, mv, index+"_index")
+					}
+
+					Print(fmt.Sprintf("Fill %s.%s in src cluster", ks, tab))
+					fillTable(t, h.srcCluster.rootSession, 1, ks, tab)
+				}
+				units = append(units, u)
+			}
+
+			Print("Run backup")
+			loc := []Location{testLocation("paused", "")}
+			S3InitBucket(t, loc[0].Path)
+			ksFilter := slices.Clone(kss)
+
+			// Starting from SM 3.3.1, SM does not allow to back up views,
+			// but backed up views should still be tested as older backups might
+			// contain them. That's why here we manually force backup target
+			// to contain the views.
+			ctx := context.Background()
+			h.srcCluster.TaskID = uuid.NewTime()
+			h.srcCluster.RunID = uuid.NewTime()
+
+			rawProps, err := json.Marshal(map[string]any{
+				"location": loc,
+				"keyspace": ksFilter,
+			})
+			if err != nil {
+				t.Fatal(errors.Wrap(err, "marshal properties"))
+			}
+
+			target, err := h.srcBackupSvc.GetTarget(ctx, h.srcCluster.ClusterID, rawProps)
+			if err != nil {
+				t.Fatal(errors.Wrap(err, "generate target"))
+			}
+			target.Units = units
+
+			err = h.srcBackupSvc.Backup(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID, target)
+			if err != nil {
+				t.Fatal(errors.Wrap(err, "run backup"))
+			}
+
+			pr, err := h.srcBackupSvc.GetProgress(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID)
+			if err != nil {
+				t.Fatal(errors.Wrap(err, "get progress"))
+			}
+			tag := pr.SnapshotTag
+
+			Print("Run restore tables")
+			grantRestoreTablesPermissions(t, h.dstCluster.rootSession, ksFilter, h.dstUser)
+			props := map[string]any{
+				"location":       loc,
+				"keyspace":       ksFilter,
+				"snapshot_tag":   tag,
+				"restore_tables": true,
+			}
+			err = runPausedRestore(t, func(ctx context.Context) error {
+				h.dstCluster.RunID = uuid.NewTime()
+				rawProps, err := json.Marshal(props)
+				if err != nil {
+					return err
+				}
+				return h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps)
+			}, tc.pauseInterval, tc.minPauseCnt)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			for _, ks := range kss {
+				for i, tab := range tabs {
+					validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, tab, "id", "data")
+					if i == 0 {
+						validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, mv, "id", "data")
+						validateTableContent[int, int](t, h.srcCluster.rootSession, h.dstCluster.rootSession, ks, index+"_index", "id", "data")
+					}
+				}
+			}
+		})
+	}
+}
+
+func createTable(t *testing.T, session gocqlx.Session, keyspace string, tables ...string) {
+	for _, tab := range tables {
+		ExecStmt(t, session, fmt.Sprintf("CREATE TABLE %q.%q (id int PRIMARY KEY, data int)", keyspace, tab))
+	}
+}
+
+func fillTable(t *testing.T, session gocqlx.Session, rowCnt int, keyspace string, tables ...string) {
+	for _, tab := range tables {
+		stmt := fmt.Sprintf("INSERT INTO %q.%q (id, data) VALUES (?, ?)", keyspace, tab)
+		q := session.Query(stmt, []string{"id", "data"})
+
+		for i := 0; i < rowCnt; i++ {
+			if err := q.Bind(i, i).Exec(); err != nil {
+				t.Fatal(err)
+			}
+		}
+
+		q.Release()
+	}
+}
+
+func runPausedRestore(t *testing.T, restore func(ctx context.Context) error, pauseInterval time.Duration, minPauseCnt int) (err error) {
+	t.Helper()
+
+	ticker := time.NewTicker(pauseInterval)
+	ctx, cancel := context.WithCancel(context.Background())
+	res := make(chan error)
+	pauseCnt := 0
+	defer func() {
+		t.Logf("Restore was paused %d times", pauseCnt)
+		if pauseCnt < minPauseCnt {
+			err = multierr.Append(err, errors.Errorf("expected to pause at least %d times, got %d", minPauseCnt, pauseCnt))
+		}
+	}()
+
+	go func() {
+		res <- restore(ctx)
+	}()
+	for {
+		select {
+		case err := <-res:
+			cancel()
+			return err
+		case <-ticker.C:
+			t.Log("Pause restore")
+			cancel()
+			err := <-res
+			if err == nil || !errors.Is(err, context.Canceled) {
+				return err
+			}
+
+			pauseCnt++
+			ctx, cancel = context.WithCancel(context.Background())
+			go func() {
+				res <- restore(ctx)
+			}()
+		}
+	}
+}
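
For context on the new helper: runPausedRestore drives a pause/resume cycle by cancelling the restore's context on every tick, requiring the interrupted run to surface context.Canceled, and then relaunching the restore (with a fresh RunID in the test) so the service resumes from its recorded progress. Below is a minimal, self-contained sketch of that driver pattern, not part of the change; slowTask and the durations are made up for illustration and stand in for the real Restore call, with a package-level counter playing the role of the service's saved progress.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// remaining is the simulated work left; keeping it outside slowTask lets a
// relaunched task resume instead of starting over (the real restore resumes
// via its persisted run progress).
var remaining = 10

// slowTask is a hypothetical stand-in for the Restore call: it works in
// small steps and returns ctx.Err() as soon as it is cancelled.
func slowTask(ctx context.Context) error {
	for remaining > 0 {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(100 * time.Millisecond):
			remaining--
		}
	}
	return nil
}

func main() {
	ticker := time.NewTicker(300 * time.Millisecond)
	defer ticker.Stop()

	ctx, cancel := context.WithCancel(context.Background())
	res := make(chan error)
	go func() { res <- slowTask(ctx) }()

	pauses := 0
	for {
		select {
		case err := <-res:
			cancel()
			fmt.Printf("finished after %d pauses: %v\n", pauses, err)
			return
		case <-ticker.C:
			cancel()     // pause: cancel the running attempt
			err := <-res // wait for it to stop
			if err == nil || !errors.Is(err, context.Canceled) {
				fmt.Println("finished or failed while pausing:", err)
				return
			}
			pauses++
			ctx, cancel = context.WithCancel(context.Background())
			go func() { res <- slowTask(ctx) }() // resume
		}
	}
}

The ratio of the ticker interval to the total task length determines how many pauses occur, which is what the minPauseCnt assertion in the test cases above relies on (an hour-long interval yields zero pauses, a 45-second interval at least two).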