From 377979cdc5b7155e7c70c843251b36eab18a8608 Mon Sep 17 00:00:00 2001 From: Drew Wells Date: Fri, 8 Nov 2024 14:50:35 -0600 Subject: [PATCH] PTEUDO-1991 on restore failure, create a fresh schema When a database is partially or fully restored, we fail to migrate to it. This is problematic as user can not easily recover from this situation. We will attempt to version the schema by migrating whatever is there and starting with a new schema. In situations where the issue is the pgdump itself, we will be creating multiple copies of an empty schema which is cheap. In more typical situations, we will move a partially or fully migrated schema to another location and attempt a restore with a fresh schema. - change restore to return sql errors rather than exit codes - sanitize dsn emitted by restore --- cmd/config/config.yaml | 5 +- .../databasecontroller_migrate_test.go | 140 ++++++++--- internal/controller/suite_test.go | 13 +- internal/controller/testdata/mock.sql | 68 +++++ internal/dockerdb/mockdb.go | 22 +- internal/dockerdb/testdb.go | 105 +++++--- internal/dockerdb/testdb_test.go | 3 +- pkg/databaseclaim/claimstatus.go | 2 +- pkg/databaseclaim/databaseclaim.go | 8 +- pkg/dbclient/client.go | 2 +- pkg/dbclient/client_test.go | 6 +- pkg/pgctl/pgctl.go | 154 ++++++------ pkg/pgctl/pgctl_test.go | 16 +- pkg/pgctl/pgdump.go | 87 +++++-- pkg/pgctl/pgrestore.go | 233 ++++++++++++------ pkg/pgctl/pgrestore_test.go | 46 +++- pkg/pgctl/subscription.go | 48 ++++ pkg/pgctl/utils.go | 21 ++ pkg/pgctl/utils_test.go | 26 ++ pkg/roleclaim/roleclaim_test.go | 15 +- 20 files changed, 739 insertions(+), 281 deletions(-) create mode 100644 internal/controller/testdata/mock.sql create mode 100644 pkg/pgctl/subscription.go create mode 100644 pkg/pgctl/utils_test.go diff --git a/cmd/config/config.yaml b/cmd/config/config.yaml index b62fea38..622571ce 100644 --- a/cmd/config/config.yaml +++ b/cmd/config/config.yaml @@ -39,7 +39,10 @@ defaultMajorVersion: 15 passwordConfig: 
passwordComplexity: enabled minPasswordLength: 15 - passwordRotationPeriod: 60 + # 60s is a common error requeue in the codebase. Set this + # to something other than 60s to distinguish between + # success requeue and those error requeues. + passwordRotationPeriod: 65s systemFunctions: ib_realm: "ib_realm" diff --git a/internal/controller/databasecontroller_migrate_test.go b/internal/controller/databasecontroller_migrate_test.go index 148884e0..2c3a6bb3 100644 --- a/internal/controller/databasecontroller_migrate_test.go +++ b/internal/controller/databasecontroller_migrate_test.go @@ -41,14 +41,14 @@ import ( "github.com/infobloxopen/db-controller/pkg/pgctl" ) -var _ = Describe("claim migrate", func() { +var _ = Describe("Migrate", func() { // Define utility constants for object names and testing timeouts/durations and intervals. var ( ctxLogger context.Context cancel func() - success = ctrl.Result{Requeue: false, RequeueAfter: 60} + success = ctrl.Result{Requeue: false, RequeueAfter: 65 * time.Second} ) BeforeEach(func() { @@ -93,12 +93,11 @@ var _ = Describe("claim migrate", func() { // Namespace: "default", // } - claim := &persistancev1.DatabaseClaim{} - kctx := context.Background() BeforeEach(func() { + claim := &persistancev1.DatabaseClaim{} By("ensuring the resource does not exist") Expect(k8sClient.Get(kctx, typeNamespacedName, claim)).To(HaveOccurred()) @@ -119,7 +118,7 @@ var _ = Describe("claim migrate", func() { }, Spec: persistancev1.DatabaseClaimSpec{ Class: ptr.To(""), - DatabaseName: "postgres", + DatabaseName: "sample_app", SecretName: claimSecretName, EnableSuperUser: ptr.To(false), EnableReplicationRole: ptr.To(false), @@ -128,6 +127,7 @@ var _ = Describe("claim migrate", func() { SourceDataFrom: &persistancev1.SourceDataFrom{ Type: "database", Database: &persistancev1.Database{ + DSN: testDSN, SecretRef: &persistancev1.SecretRef{ Name: sourceSecretName, Namespace: "default", @@ -190,10 +190,12 @@ var _ = Describe("claim migrate", func() { _, err 
:= controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName}) Expect(err).NotTo(HaveOccurred()) + var claim persistancev1.DatabaseClaim + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &claim)).NotTo(HaveOccurred()) Expect(claim.Status.Error).To(Equal("")) By("Ensuring the active db connection info is set") Eventually(func() *persistancev1.DatabaseClaimConnectionInfo { - Expect(k8sClient.Get(ctxLogger, typeNamespacedName, claim)).NotTo(HaveOccurred()) + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &claim)).NotTo(HaveOccurred()) return claim.Status.ActiveDB.ConnectionInfo }).ShouldNot(BeNil()) @@ -223,16 +225,18 @@ var _ = Describe("claim migrate", func() { Expect(dsn.Redacted()).To(Equal(redacted)) }) - It("Migrate", func() { + It("Populate a new database", func() { Expect(controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName})).To(Equal(success)) var dbc persistancev1.DatabaseClaim Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) - Expect(claim.Status.Error).To(Equal("")) + Expect(dbc.Status.Error).To(Equal("")) By("Ensuring the active db connection info is set") + Expect(controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName})).To(Equal(success)) + Eventually(func() *persistancev1.DatabaseClaimConnectionInfo { Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) - return claim.Status.ActiveDB.ConnectionInfo + return dbc.Status.ActiveDB.ConnectionInfo }).ShouldNot(BeNil()) hostParams, err := hostparams.New(controllerReconciler.Config.Viper, &dbc) @@ -242,7 +246,7 @@ var _ = Describe("claim migrate", func() { By(fmt.Sprintf("Mocking a RDS pod to look like crossplane set it up: %s", fakeCPSecretName)) fakeCli, fakeDSN, fakeCancel := dockerdb.MockRDS(GinkgoT(), ctxLogger, k8sClient, fakeCPSecretName, "migrate", dbc.Spec.DatabaseName) - DeferCleanup(fakeCancel) + defer 
fakeCancel() fakeCPSecret := corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ @@ -269,7 +273,7 @@ var _ = Describe("claim migrate", func() { By("Check source DSN looks roughly correct") activeDB := dbc.Status.ActiveDB Expect(activeDB.ConnectionInfo).NotTo(BeNil()) - compareDSN := strings.Replace(testDSN, "//postgres:postgres", fmt.Sprintf("//%s_b:", migratedowner), 1) + compareDSN := strings.Replace(testDSN, "//postgres:postgres", fmt.Sprintf("//%s_a:", migratedowner), 1) Expect(activeDB.ConnectionInfo.Uri()).To(Equal(compareDSN)) By("Check target DSN looks roughly correct") @@ -282,28 +286,18 @@ var _ = Describe("claim migrate", func() { var tempCreds corev1.Secret // temp-migrate-dbclaim-creds Expect(k8sClient.Get(ctxLogger, types.NamespacedName{Name: "temp-" + claimSecretName, Namespace: "default"}, &tempCreds)).NotTo(HaveOccurred()) - for k, v := range tempCreds.Data { - logger.Info("tempcreds", k, string(v)) - } - By("CR reconciles but must be requeued to perform migration, reconcile manually for test") - res, err = controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName}) - Expect(err).To(BeNil()) - Expect(res.Requeue).To(BeFalse()) - Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) - Expect(dbc.Status.Error).To(Equal("")) + By("Requeue for as long as Controller requests it") + Eventually(func() reconcile.Result { + res, err = controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName}) + Expect(err).To(BeNil()) + return res + }).Should(Equal(success)) - By("Waiting to disable source, reconcile manually again") - res, err = controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName}) - Expect(err).To(BeNil()) - Expect(res.RequeueAfter).To(Equal(time.Duration(60 * time.Second))) - By("Verify migration is complete on this reconcile") - res, err = controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: 
typeNamespacedName}) - Expect(err).To(BeNil()) - Expect(res.Requeue).To(BeFalse()) Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) Expect(dbc.Status.Error).To(Equal("")) Expect(dbc.Status.MigrationState).To(Equal(pgctl.S_Completed.String())) + activeDB = dbc.Status.ActiveDB Expect(activeDB.ConnectionInfo).NotTo(BeNil()) Expect(activeDB.ConnectionInfo.Uri()).To(Equal(compareDSN)) @@ -316,5 +310,95 @@ var _ = Describe("claim migrate", func() { }) + It("Existing schema", func() { + + Expect(controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName})).To(Equal(success)) + var dbc persistancev1.DatabaseClaim + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) + Expect(dbc.Status.Error).To(Equal("")) + By("Ensuring the active db connection info is set") + Eventually(func() *persistancev1.DatabaseClaimConnectionInfo { + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) + return dbc.Status.ActiveDB.ConnectionInfo + }).ShouldNot(BeNil()) + + oldDB, err := url.Parse(testDSN) + Expect(err).NotTo(HaveOccurred()) + aConn := dbc.Status.ActiveDB.ConnectionInfo + Expect(aConn.Port).To(Equal(oldDB.Port())) + + logger.Info("what", "status", dbc.Status.ActiveDB.ConnectionInfo.Uri()) + + hostParams, err := hostparams.New(controllerReconciler.Config.Viper, &dbc) + Expect(err).ToNot(HaveOccurred()) + + fakeCPSecretName := fmt.Sprintf("%s-%s-%s", env, resourceName, hostParams.Hash()) + + By(fmt.Sprintf("Mocking a RDS pod to look like crossplane set it up: %s", fakeCPSecretName)) + fakeCli, fakeDSN, fakeCancel := dockerdb.MockRDS(GinkgoT(), ctxLogger, k8sClient, fakeCPSecretName, "migrate", dbc.Spec.DatabaseName) + defer fakeCancel() + + dockerdb.MustSQL(ctxLogger, fakeCli, "testdata/mock.sql") + + fakeCPSecret := corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: fakeCPSecretName, + Namespace: "default", + }, + } + nname := types.NamespacedName{ + Name: 
fakeCPSecretName, + Namespace: "default", + } + Eventually(k8sClient.Get(ctxLogger, nname, &fakeCPSecret)).Should(Succeed()) + logger.Info("debugsecret", "rdssecret", fakeCPSecret) + + By("Disabling UseExistingSource") + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) + dbc.Spec.UseExistingSource = ptr.To(false) + Expect(k8sClient.Update(ctxLogger, &dbc)).NotTo(HaveOccurred()) + + res, err := controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName}) + Expect(err).To(BeNil()) + Expect(dbc.Status.Error).To(Equal("")) + Expect(res.Requeue).To(BeTrue()) + + By("Check DSNs looks roughly correct") + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) + activeDB := dbc.Status.ActiveDB + Expect(activeDB.ConnectionInfo).NotTo(BeNil()) + compareDSN := strings.Replace(testDSN, "//postgres:postgres", fmt.Sprintf("//%s_a:", migratedowner), 1) + Expect(activeDB.ConnectionInfo.Uri()).To(Equal(compareDSN)) + + newDB := dbc.Status.NewDB + compareDSN = strings.Replace(fakeDSN, "//migrate:postgres", fmt.Sprintf("//%s_a:", migratedowner), 1) + + Expect(newDB.ConnectionInfo).NotTo(BeNil()) + Expect(newDB.ConnectionInfo.Uri()).To(Equal(compareDSN)) + + By("Ensuring migration is in progress") + Expect(dbc.Status.MigrationState).To(Equal(pgctl.S_MigrationInProgress.String())) + var tempCreds corev1.Secret + // temp-migrate-dbclaim-creds + Expect(k8sClient.Get(ctxLogger, types.NamespacedName{Name: "temp-" + claimSecretName, Namespace: "default"}, &tempCreds)).NotTo(HaveOccurred()) + + By("Requeue for as long as Controller requests it") + Eventually(func() reconcile.Result { + + res, err = controllerReconciler.Reconcile(ctxLogger, reconcile.Request{NamespacedName: typeNamespacedName}) + Expect(err).To(BeNil()) + return res + }).Should(Equal(success)) + + Expect(k8sClient.Get(ctxLogger, typeNamespacedName, &dbc)).NotTo(HaveOccurred()) + Expect(dbc.Status.Error).To(Equal("")) + + 
Expect(dbc.Status.MigrationState).To(Equal(pgctl.S_Completed.String())) + activeDB = dbc.Status.ActiveDB + Expect(activeDB.ConnectionInfo).NotTo(BeNil()) + Expect(activeDB.ConnectionInfo.Uri()).To(Equal(compareDSN)) + + }) + }) }) diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go index d09e14fc..74fc65f1 100644 --- a/internal/controller/suite_test.go +++ b/internal/controller/suite_test.go @@ -17,6 +17,7 @@ limitations under the License. package controller import ( + "context" "database/sql" "fmt" "path/filepath" @@ -130,7 +131,8 @@ var _ = BeforeSuite(func() { Expect(k8sClient).NotTo(BeNil()) now := time.Now() - testdb, testDSN, cleanupTestDB = dockerdb.Run(dockerdb.Config{ + logger.Info("start postgres setup") + testdb, testDSN, cleanupTestDB = dockerdb.Run(logger, dockerdb.Config{ Database: "postgres", Username: "postgres", Password: "postgres", @@ -138,14 +140,7 @@ var _ = BeforeSuite(func() { }) logger.Info("postgres_setup_took", "duration", time.Since(now)) - // Mock table for testing migrations - _, err = testdb.Exec(`CREATE TABLE IF NOT EXISTS users ( - id SERIAL PRIMARY KEY, - name TEXT NOT NULL, - email TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - )`) - Expect(err).NotTo(HaveOccurred()) + dockerdb.MustSQL(context.TODO(), testdb, "testdata/mock.sql") // Setup controller By("setting up the database controller") diff --git a/internal/controller/testdata/mock.sql b/internal/controller/testdata/mock.sql new file mode 100644 index 00000000..8520bccb --- /dev/null +++ b/internal/controller/testdata/mock.sql @@ -0,0 +1,68 @@ +-- Create a table with a primary key and a sequence +CREATE TABLE users ( + user_id SERIAL PRIMARY KEY, + username VARCHAR(50) UNIQUE NOT NULL, + email VARCHAR(100) UNIQUE NOT NULL +); + +-- Create another table with a foreign key reference to users +CREATE TABLE orders ( + order_id SERIAL PRIMARY KEY, + user_id INT REFERENCES users(user_id) ON DELETE CASCADE, + order_total NUMERIC(10, 2) 
NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create a table with a check constraint +CREATE TABLE products ( + product_id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + price NUMERIC(10, 2) NOT NULL CHECK (price > 0) +); + +-- Create a function to calculate the total price of orders for a user +CREATE OR REPLACE FUNCTION calculate_user_total(userId INT) RETURNS NUMERIC AS $$ + DECLARE + total NUMERIC(10, 2); + BEGIN + SELECT COALESCE(SUM(order_total), 0) INTO total + FROM orders WHERE user_id = userId; + RETURN total; + END; +$$ LANGUAGE plpgsql; + +-- Create an index on the orders table to optimize queries by created_at +CREATE INDEX idx_orders_created_at ON orders(created_at); + +-- Create a materialized view that summarizes total orders per user +CREATE MATERIALIZED VIEW user_order_totals AS +SELECT u.user_id, u.username, COALESCE(SUM(o.order_total), 0) AS total_spent +FROM users u +LEFT JOIN orders o ON u.user_id = o.user_id +GROUP BY u.user_id, u.username; + +-- Create a trigger to log inserts into the orders table +CREATE TABLE order_logs ( + log_id SERIAL PRIMARY KEY, + order_id INT NOT NULL, + log_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE OR REPLACE FUNCTION log_order_insert() RETURNS TRIGGER AS $$ +BEGIN + INSERT INTO order_logs (order_id) VALUES (NEW.order_id); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER after_order_insert +AFTER INSERT ON orders +FOR EACH ROW EXECUTE FUNCTION log_order_insert(); + +-- Create a table with an exclusion constraint +CREATE TABLE event_schedule ( + event_id SERIAL PRIMARY KEY, + event_name VARCHAR(100) NOT NULL, + event_time TSRANGE NOT NULL, + EXCLUDE USING gist (event_time WITH &&) -- Ensures no overlapping events +); diff --git a/internal/dockerdb/mockdb.go b/internal/dockerdb/mockdb.go index 6e2be8d2..acd95897 100644 --- a/internal/dockerdb/mockdb.go +++ b/internal/dockerdb/mockdb.go @@ -3,6 +3,8 @@ package dockerdb import ( "context" "database/sql" + "fmt" + 
"io/ioutil" "net/url" "strings" @@ -11,12 +13,28 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" ) +// MustSQL will open a file and exec it on the database. +// It will panic on any errors +func MustSQL(ctx context.Context, db *sql.DB, fileName string, args ...any) { + bs, err := ioutil.ReadFile(fileName) + if err != nil { + panic(err) + } + _, err = db.ExecContext(ctx, string(bs), args...) + if err != nil { + panic(err) + } +} + func MockRDS(t GinkgoTInterface, ctx context.Context, cli client.Client, secretName, userName, databaseName string) (*sql.DB, string, func()) { t.Helper() - dbCli, fakeDSN, clean := Run(Config{ + logger := log.FromContext(ctx).WithName("mockrds") + + dbCli, fakeDSN, clean := Run(logger, Config{ Database: databaseName, Username: userName, Password: "postgres", @@ -49,7 +67,7 @@ func MockRDS(t GinkgoTInterface, ctx context.Context, cli client.Client, secretN return dbCli, fakeDSN, func() { if err := cli.Delete(ctx, secret); err != nil { - t.Logf("failed to delete secret: %v", err) + panic(fmt.Sprintf("failed to delete secret: %v", err)) } clean() } diff --git a/internal/dockerdb/testdb.go b/internal/dockerdb/testdb.go index c2a1d9ea..0333f405 100644 --- a/internal/dockerdb/testdb.go +++ b/internal/dockerdb/testdb.go @@ -1,11 +1,14 @@ package dockerdb import ( + "bufio" "bytes" + "context" "crypto/rand" "database/sql" "encoding/base64" "fmt" + "io" "log" "net" "net/url" @@ -15,29 +18,12 @@ import ( "time" "github.com/go-logr/logr" - "go.uber.org/zap/zapcore" - "sigs.k8s.io/controller-runtime/pkg/log/zap" ) -// logger is used since some times we run from testing.M and testing.T is not available -var logger logr.Logger - -// DebugLevel is used to set V level to 1 as suggested by official docs -// https://github.com/kubernetes-sigs/controller-runtime/blob/main/TMP-LOGGING.md -const debugLevel = 1 - -func init() { - // Use 
zap logger - opts := zap.Options{ - Development: true, - // Enable this to debug this code - //Level: zapcore.DebugLevel, - Level: zapcore.InfoLevel, - } - - logger = zap.New(zap.UseFlagOptions(&opts)) +var debugLevel = 1 -} +// This does not write to stdout or stderr +var logger logr.Logger func getEphemeralPort() int { l, err := net.Listen("tcp", "localhost:0") @@ -182,9 +168,11 @@ type Config struct { // Run a PostgreSQL database in a Docker container and return a connection to it. // The caller is responsible for calling the func() to prevent leaking containers. -func Run(cfg Config) (*sql.DB, string, func()) { +func Run(log logr.Logger, cfg Config) (*sql.DB, string, func()) { port := getEphemeralPort() + ctx, cancel := context.WithCancel(context.Background()) + // Required parameters if cfg.Database == "" { panic("database name is required") @@ -223,24 +211,27 @@ func Run(cfg Config) (*sql.DB, string, func()) { ctrArgs := []string{fmt.Sprintf("postgres:%s", cfg.DockerTag), "postgres", "-c", "wal_level=logical"} // Run PostgreSQL in Docker - cmd := exec.Command("docker", append(args, ctrArgs...)...) - logger.V(debugLevel).Info(cmd.String()) + cmd := exec.CommandContext(ctx, "docker", append(args, ctrArgs...)...) 
+ log.V(debugLevel).Info(cmd.String()) + var stderr bytes.Buffer cmd.Stderr = &stderr out, err := cmd.Output() if err != nil { - logger.Error(err, "failed to run docker container") - logger.Info(cmd.String()) - logger.Info("stderr:" + stderr.String()) + log.Error(err, "failed to run docker container") + log.Info(cmd.String()) + log.Info("stderr:" + stderr.String()) os.Exit(1) } - logger.V(debugLevel).Info(string(out)) + log.V(debugLevel).Info(string(out)) container := string(out[:len(out)-1]) // remove newline + connectLogger(ctx, log, container) + // Exercise hotload //hotload.RegisterSQLDriver("pgx", stdlib.GetDefaultDriver()) dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", url.QueryEscape(cfg.Username), url.QueryEscape(cfg.Password), GetOutboundIP(), port, cfg.Database) - logger.V(debugLevel).Info(dsn) + log.V(debugLevel).Info(dsn) f, err := os.CreateTemp("", "dsn.txt") if err != nil { panic(err) @@ -271,43 +262,73 @@ CREATE ROLE alloydbsuperuser WITH INHERIT LOGIN`) } if err != nil { - logger.Error(err, "failed to connect to database") + log.Error(err, "failed to connect to database") cmd = exec.Command("docker", "logs", container) cmd.Stderr = os.Stderr out, err := cmd.Output() if err != nil { - logger.Error(err, "failed to get logs") + log.Error(err, "failed to get logs") } - logger.Info(string(out)) + log.Info(string(out)) os.Exit(1) } // TODO: change this to debug logging, just timing jenkins for now - logger.Info("db_connected", "dsn", dsn, "duration", time.Since(now)) + log.Info("db_connected", "dsn", dsn, "duration", time.Since(now)) return conn, dsn, func() { // Cleanup container on close, dont exit without trying all steps first now := time.Now() defer func() { - logger.V(debugLevel).Info("container_cleanup_took", "duration", time.Since(now)) + cancel() + log.V(debugLevel).Info("container_cleanup_took", "duration", time.Since(now)) }() err := os.Remove(f.Name()) if err != nil { - logger.Error(err, "failed to remove temp file") + 
log.Error(err, "failed to remove temp file") } cmd := exec.Command("docker", "rm", "-f", container) // This take 10 seconds to run, and we don't care if // it was successful. So use Start() to not wait for // it to finish. - logger.V(debugLevel).Info(cmd.String()) + log.V(debugLevel).Info(cmd.String()) if err := cmd.Start(); err != nil { - logger.Error(err, "failed to remove container") + log.Error(err, "failed to remove container") } } } +func connectLogger(ctx context.Context, logger logr.Logger, containerID string) { + + log := logger.WithName("testdb") + // Connect to the container's logs + cmd := exec.CommandContext(ctx, "docker", "logs", "-f", containerID) + + log.Info("connecting to container logs", "cmd", cmd.String()) + + stdout, err := cmd.StdoutPipe() + if err != nil { + log.Error(err, "failed to get stdout pipe") + os.Exit(1) + } + stderr, err := cmd.StderrPipe() + if err != nil { + log.Error(err, "failed to get stderr pipe") + os.Exit(1) + } + + go logStdouterr(stdout, log) + go logStdouterr(stderr, log) + + err = cmd.Start() + if err != nil { + logger.Error(err, "failed to start command") + return + } +} + func GetOutboundIP() string { conn, err := net.Dial("udp", "8.8.8.8:80") if err != nil { @@ -319,3 +340,15 @@ func GetOutboundIP() string { return localAddr.IP.String() } + +func logStdouterr(out io.ReadCloser, logger logr.Logger) { + scanner := bufio.NewScanner(out) + for scanner.Scan() { + // This is very noisy and left off by default + // Consider wiring this up to a test verbosity setting + //logger.V(1).Info(scanner.Text()) + } + if err := scanner.Err(); err != nil { + logger.Error(err, "Error reading command output") + } +} diff --git a/internal/dockerdb/testdb_test.go b/internal/dockerdb/testdb_test.go index 65451553..1361eeac 100644 --- a/internal/dockerdb/testdb_test.go +++ b/internal/dockerdb/testdb_test.go @@ -3,11 +3,12 @@ package dockerdb import ( "testing" + "github.com/go-logr/logr" _ "github.com/lib/pq" ) func TestDB(t *testing.T) { 
- db, _, cleanup := Run(Config{Database: "testdb"}) + db, _, cleanup := Run(logr.Logger{}, Config{Database: "testdb"}) defer cleanup() defer db.Close() _, err := db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY, name TEXT)") diff --git a/pkg/databaseclaim/claimstatus.go b/pkg/databaseclaim/claimstatus.go index 2d853736..d15a2953 100644 --- a/pkg/databaseclaim/claimstatus.go +++ b/pkg/databaseclaim/claimstatus.go @@ -23,7 +23,7 @@ func (r *DatabaseClaimReconciler) manageError(ctx context.Context, dbClaim *v1.D // Class of errors that should stop the reconciliation loop // but not cause a status change on the CR if errors.Is(inErr, ErrDoNotUpdateStatus) { - return ctrl.Result{}, nil + return ctrl.Result{RequeueAfter: r.getPasswordRotationTime()}, nil } return manageError(ctx, r.Client, dbClaim, inErr) } diff --git a/pkg/databaseclaim/databaseclaim.go b/pkg/databaseclaim/databaseclaim.go index d9206dbb..25558391 100644 --- a/pkg/databaseclaim/databaseclaim.go +++ b/pkg/databaseclaim/databaseclaim.go @@ -446,7 +446,9 @@ func (r *DatabaseClaimReconciler) reconcileUseExistingDB(ctx context.Context, re } } - logr.Info("status_block", "status", dbClaim.Status) + // Reset status as this is a new cycle + dbClaim.Status.Error = "" + dbClaim.Status.MigrationState = "" sourceDSN, err := auth.GetSourceDataFromDSN(ctx, r.Client, dbClaim) if err != nil { @@ -486,8 +488,6 @@ func (r *DatabaseClaimReconciler) reconcileUseExistingDB(ctx context.Context, re } defer dbClient.Close() - logr.Info(fmt.Sprintf("processing DBClaim: %s namespace: %s AppID: %s", dbClaim.Name, dbClaim.Namespace, dbClaim.Spec.AppID)) - dbName := existingDBConnInfo.DatabaseName updateDBStatus(&dbClaim.Status.NewDB, dbName) @@ -915,7 +915,7 @@ loop: logr.Error(err, "ignoring delete temp secret error") } //create connection info secret - logr.Info("migration complete") + logr.Info("migration complete", "status", dbClaim.Status) return r.manageSuccess(ctx, dbClaim) } diff --git a/pkg/dbclient/client.go 
b/pkg/dbclient/client.go index 35a7e921..f0e6b40a 100644 --- a/pkg/dbclient/client.go +++ b/pkg/dbclient/client.go @@ -395,7 +395,7 @@ func (pc *client) CreateSchema(schemaName string) (bool, error) { createSchema := strings.Replace(` CREATE SCHEMA IF NOT EXISTS "%schema%"; REVOKE ALL ON SCHEMA "%schema%" FROM PUBLIC; - GRANT USAGE ON SCHEMA "%schema%" TO PUBLIC; + GRANT USAGE ON SCHEMA "%schema%" TO PUBLIC; REVOKE ALL ON ALL TABLES IN SCHEMA "%schema%" FROM PUBLIC ; GRANT SELECT ON ALL TABLES IN SCHEMA "%schema%" TO PUBLIC; `, "%schema%", schemaName, -1) diff --git a/pkg/dbclient/client_test.go b/pkg/dbclient/client_test.go index 4527683e..3dbc2532 100644 --- a/pkg/dbclient/client_test.go +++ b/pkg/dbclient/client_test.go @@ -20,7 +20,9 @@ func NewTestLogger(t *testing.T) logr.Logger { func TestPostgresClientOperations(t *testing.T) { - db, dsn, close := dockerdb.Run(dockerdb.Config{ + testLogger := NewTestLogger(t) + + db, dsn, close := dockerdb.Run(testLogger, dockerdb.Config{ Username: "test", Password: "pa@ss$){[d~&!@#$%^*()_+`-={}|[]:<>?,./", Database: "postgres", @@ -75,7 +77,7 @@ func TestPostgresClientOperations(t *testing.T) { dbURL: dsn, DB: db, adminDB: db, - log: NewTestLogger(t), + log: testLogger, } got, err := pc.CreateDatabase(tt.args.dbName) diff --git a/pkg/pgctl/pgctl.go b/pkg/pgctl/pgctl.go index 41472bef..d966ff04 100644 --- a/pkg/pgctl/pgctl.go +++ b/pkg/pgctl/pgctl.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/url" + "path/filepath" "github.com/go-logr/logr" "github.com/infobloxopen/db-controller/pkg/metrics" @@ -211,7 +212,7 @@ func (s *validate_connection_state) String() string { func (s *create_publication_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") var ( err error @@ -229,8 +230,8 @@ func (s *create_publication_state) Execute() (State, error) { q := ` SELECT EXISTS( - SELECT pubname - FROM pg_catalog.pg_publication + SELECT pubname + FROM 
pg_catalog.pg_publication WHERE pubname = $1)` // dynamically creating the table list to be included in the publication @@ -270,17 +271,19 @@ func (s *create_publication_state) Execute() (State, error) { log.Error(err, "could not query for publication name") return nil, err } - if !exists { - log.Info("creating publication:", "with name", DefaultPubName) - if _, err := sourceDBAdmin.Exec(createPub); err != nil { - log.Error(err, "create publication failed") - return nil, err - } - log.Info("publication created", "name", DefaultPubName) - } else { + if exists { log.Info("publication already exists", "with name", DefaultPubName) + return ©_schema_state{ + config: s.config, + }, nil } + if _, err := sourceDBAdmin.Exec(createPub); err != nil { + log.Error(err, "create publication failed") + return nil, err + } + log.V(1).Info("publication created", "name", DefaultPubName) + return ©_schema_state{ config: s.config, }, nil @@ -315,7 +318,7 @@ var revokeSuperUserAccess = func(DBAdmin *sql.DB, userrole string, cloud string) // dumping the schema, and restoring it to the target database. 
func (s *copy_schema_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") //grant rds_superuser to target db user temporarily //this is required to copy schema from source to target and take ownership of the objects @@ -336,6 +339,7 @@ func (s *copy_schema_state) Execute() (State, error) { return nil, err } + // FIXME: actually use the role name, not the rotating user name rolename := url.User.Username() log.Info("granting super user access", "role", rolename) @@ -346,49 +350,38 @@ func (s *copy_schema_state) Execute() (State, error) { return nil, err } - dump := NewDump(s.config.SourceDBAdminDsn) - dump.SetupFormat("p") - dump.SetPath(s.config.ExportFilePath) - dump.EnableVerbose() - dump.SetOptions([]string{ + dump := NewDump(s.config.SourceDBAdminDsn, WithFormat("p"), WithPath(s.config.ExportFilePath), WithLogger(s.config.Log.V(1)), WithVerbose(true), WithOptions([]string{ "--schema-only", "--no-publication", "--no-subscriptions", "--no-privileges", "--no-owner", "--exclude-schema=ib", - }) + })) dumpExec := dump.Exec(ExecOptions{StreamPrint: true}) if dumpExec.Error != nil { return nil, dumpExec.Error.Err } - log.Info("executed dump", "full command", dumpExec.FullCommand) if err = dump.modifyPgDumpInfo(); err != nil { log.Error(err, "failed to comment create policy") return nil, err } - restore := NewRestore(s.config.TargetDBUserDsn) - restore.EnableVerbose() - restore.SetPath(s.config.ExportFilePath) - - restoreExec := restore.Exec(dumpExec.FileName, ExecOptions{StreamPrint: true}) - if restoreExec.Error != nil { - log.Error(restoreExec.Error.Err, "restore failed") - - // Attempt to drop schemas after restore failure. 
- dropResult := restore.DropSchemas() - if dropResult.Error != nil { - log.Error(dropResult.Error.Err, "failed to drop schemas") - return nil, dropResult.Error.Err - } - - return nil, restoreExec.Error.Err + restore, err := NewRestore(s.config.TargetDBAdminDsn, s.config.TargetDBUserDsn, getParentRole(rolename), WithRestoreLogger(s.config.Log)) + if err != nil { + return nil, err } + defer restore.Close() - log.Info("executed restore", "full command", restoreExec.FullCommand) + dumpPath := filepath.Join(s.config.ExportFilePath, dumpExec.FileName) + err = restore.Exec(dumpPath, ExecOptions{StreamPrint: true}) + if err != nil { + log.Error(err, "restore failed") + return nil, err + } + log.Info("restore_successful") err = revokeSuperUserAccess(targetDBAdmin, rolename, s.config.Cloud) if err != nil { @@ -397,7 +390,7 @@ func (s *copy_schema_state) Execute() (State, error) { return nil, err } - log.Info("completed") + log.V(1).Info("completed") return &create_subscription_state{ config: s.config, }, nil @@ -417,7 +410,7 @@ var getSourceDbAdminDSNForCreateSubscription = func(c *Config) string { func (s *create_subscription_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") var exists bool targetDBAdmin, err := getDB(s.config.TargetDBAdminDsn, nil) @@ -434,40 +427,42 @@ func (s *create_subscription_state) Execute() (State, error) { )` createSub := fmt.Sprintf(` CREATE SUBSCRIPTION %s - CONNECTION '%s' - PUBLICATION %s + CONNECTION '%s' + PUBLICATION %s WITH (enabled=false)`, DefaultSubName, getSourceDbAdminDSNForCreateSubscription(&s.config), DefaultPubName) + log.V(1).Info("started subscription", "sub", createSub) + err = targetDBAdmin.QueryRow(q, DefaultSubName).Scan(&exists) if err != nil { log.Error(err, "could not query for subscription name", "stmt", createSub) return nil, err } - if !exists { - log.Info("creating subscription:", "with name", DefaultSubName) - if _, err := 
targetDBAdmin.Exec(createSub); err != nil { - log.Error(err, "could not create subscription") - return nil, err - } - log.Info("subscription created", "name", DefaultSubName) + + if _, err := targetDBAdmin.Exec(createSub); err != nil { + log.Error(err, "could not create subscription") + return nil, err } - log.Info("completed") + + log.V(1).Info("completed", "subscription_name", DefaultSubName) return &enable_subscription_state{ config: s.config, }, nil } + func (s *create_subscription_state) Id() StateEnum { return S_CreateSubscription } + func (s *create_subscription_state) String() string { return S_CreateSubscription.String() } func (s *enable_subscription_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") targetDBAdmin, err := getDB(s.config.TargetDBAdminDsn, nil) if err != nil { @@ -501,7 +496,7 @@ func (s *enable_subscription_state) Execute() (State, error) { return nil, fmt.Errorf("unable to enable subscription. 
subscription not found - %s", DefaultSubName) } - log.Info("completed") + log.V(1).Info("completed") return &cut_over_readiness_check_state{ config: s.config, @@ -516,7 +511,7 @@ func (s *enable_subscription_state) String() string { func (s *cut_over_readiness_check_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") var exists bool var count int @@ -537,23 +532,24 @@ func (s *cut_over_readiness_check_state) Execute() (State, error) { pubQuery := fmt.Sprintf(` SELECT EXISTS ( - SELECT 1 - FROM pg_replication_slots - WHERE slot_type = 'logical' + SELECT 1 + FROM pg_replication_slots + WHERE slot_type = 'logical' AND slot_name like '%s_%%' AND temporary = 't' )`, DefaultPubName) + // relExistsQuery := fmt.Sprintf(` - // SELECT EXISTS - // ( - // SELECT 1 - // FROM pg_subscription s, pg_subscription_rel sr - // WHERE s.oid = sr.srsubid - // AND s.subname = '%s' - // )`, DefaultSubName) + // SELECT EXISTS + // ( + // SELECT 1 + // FROM pg_subscription s, pg_subscription_rel sr + // WHERE s.oid = sr.srsubid + // AND s.subname = '%s' + // )`, DefaultSubName) subQuery := fmt.Sprintf(` - SELECT count(srrelid) + SELECT count(srrelid) FROM pg_subscription s, pg_subscription_rel sr WHERE s.oid = sr.srsubid AND sr.srsubstate not in ('r', 's') @@ -580,6 +576,13 @@ func (s *cut_over_readiness_check_state) Execute() (State, error) { // return retry(s.config), nil // } + state, err := getSubscriptionStatus(targetDBAdmin, DefaultSubName) + if err != nil { + log.Error(err, "could not get subscription status") + } else { + log.Info("subscription_status", "state", state) + } + err = targetDBAdmin.QueryRow(subQuery).Scan(&count) if err != nil { log.Error(err, "could not query for subscription for completion") @@ -587,9 +590,10 @@ func (s *cut_over_readiness_check_state) Execute() (State, error) { } if count > 0 { log.Info("migration not complete in target - retry check in a few seconds") + return 
retry(s.config), nil } - log.Info("completed") + log.V(1).Info("completed") return &reset_target_sequence_state{ config: s.config, }, nil @@ -603,7 +607,7 @@ func (s *cut_over_readiness_check_state) String() string { func (s *reset_target_sequence_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") type seqCount struct { seqName string @@ -660,7 +664,7 @@ func (s *reset_target_sequence_state) Execute() (State, error) { return nil, err } } - log.Info("completed") + log.V(1).Info("completed") return &reroute_target_secret_state{ config: s.config, }, nil @@ -699,7 +703,7 @@ func (s *wait_to_disable_source_state) String() string { func (s *disable_source_access_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") sourceDBAdmin, err := getDB(s.config.SourceDBAdminDsn, nil) if err != nil { @@ -727,7 +731,7 @@ func (s *disable_source_access_state) Execute() (State, error) { log.Error(err, "failed revoking access for source db - "+rolename) return nil, err } - log.Info("completed") + log.V(1).Info("completed") return &validate_migration_status_state{ config: s.config, }, nil @@ -742,7 +746,7 @@ func (s *disable_source_access_state) String() string { func (s *validate_migration_status_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") var ( sourceTableName string @@ -821,7 +825,7 @@ func (s *validate_migration_status_state) Execute() (State, error) { return nil, err } if deuce { - log.Info("completed") + log.V(1).Info("completed") return &disable_subscription_state{ config: s.config, }, nil @@ -842,7 +846,7 @@ func (s *validate_migration_status_state) String() string { func (s *disable_subscription_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + 
log.V(1).Info("started") var exists bool @@ -873,7 +877,7 @@ func (s *disable_subscription_state) Execute() (State, error) { }, nil } - log.Info("completed") + log.V(1).Info("completed") return &delete_subscription_state{ config: s.config, @@ -888,7 +892,7 @@ func (s *disable_subscription_state) String() string { func (s *delete_subscription_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") var exists bool targetDBAdmin, err := getDB(s.config.TargetDBAdminDsn, nil) @@ -918,7 +922,7 @@ func (s *delete_subscription_state) Execute() (State, error) { } log.Info("Subscription deleted") } - log.Info("completed") + log.V(1).Info("completed") return &delete_publication_state{ config: s.config, }, nil @@ -932,7 +936,7 @@ func (s *delete_subscription_state) String() string { func (s *delete_publication_state) Execute() (State, error) { log := s.config.Log.WithValues("state", s.String()) - log.Info("started") + log.V(1).Info("started") var exists bool @@ -965,7 +969,7 @@ func (s *delete_publication_state) Execute() (State, error) { log.Info("publication not found. 
ignoring and moving on") } - log.Info("completed") + log.V(1).Info("completed") return &completed_state{ config: s.config, }, nil diff --git a/pkg/pgctl/pgctl_test.go b/pkg/pgctl/pgctl_test.go index 9787159b..36c3c001 100644 --- a/pkg/pgctl/pgctl_test.go +++ b/pkg/pgctl/pgctl_test.go @@ -14,6 +14,7 @@ import ( "github.com/go-logr/logr" "github.com/infobloxopen/db-controller/internal/dockerdb" "github.com/lib/pq" + "go.uber.org/zap/zapcore" "sigs.k8s.io/controller-runtime/pkg/log/zap" ) @@ -42,13 +43,14 @@ var logger logr.Logger func TestMain(m *testing.M) { //need to do this trick to avoid os.Exit bypassing defer logic - //with this silly setup, defer is called in realTestMain before the exit is called in this func + //with this silly setup, defer is called in realTestMain before the exit is called in this func os.Exit(setupAndRunTests(m)) } func setupAndRunTests(m *testing.M) int { opts := zap.Options{ Development: true, + Level: zapcore.InfoLevel, } logger = zap.New(zap.UseFlagOptions(&opts)) @@ -61,7 +63,7 @@ func setupAndRunTests(m *testing.M) int { // migration. // FIXME: randomly generate network name - _, sourceDSN, sourceClose := dockerdb.Run(dockerdb.Config{ + _, sourceDSN, sourceClose := dockerdb.Run(logger, dockerdb.Config{ HostName: "pubHost", DockerTag: sourceVersion, Database: "pub", @@ -75,7 +77,7 @@ func setupAndRunTests(m *testing.M) int { panic(err) } - _, targetDSN, targetClose := dockerdb.Run(dockerdb.Config{ + _, targetDSN, targetClose := dockerdb.Run(logger, dockerdb.Config{ HostName: "subHost", DockerTag: targetVersion, Database: "sub", @@ -103,7 +105,7 @@ func setupAndRunTests(m *testing.M) int { // Set up source and target databases for unit testing each step of the // migration.
- _, dataTestSourceAdminDSN, dataTestSourceClose := dockerdb.Run(dockerdb.Config{ + _, dataTestSourceAdminDSN, dataTestSourceClose := dockerdb.Run(logger, dockerdb.Config{ HostName: "dataTestSourceHost", DockerTag: sourceVersion, Database: "dataTestSource", @@ -117,7 +119,7 @@ func setupAndRunTests(m *testing.M) int { panic(err) } - _, dataTestTargetAdminDSN, dataTestTargetClose := dockerdb.Run(dockerdb.Config{ + _, dataTestTargetAdminDSN, dataTestTargetClose := dockerdb.Run(logger, dockerdb.Config{ HostName: "dataTestTargetHost", DockerTag: targetVersion, Database: "dataTestTarget", @@ -144,11 +146,11 @@ func setupAndRunTests(m *testing.M) int { // ----------------------------------------------------------------------- // Set up a database for testing the drop schema functionality. - _, dropSchemaDSN, dropSchemaClose := dockerdb.Run(dockerdb.Config{ + _, dropSchemaDSN, dropSchemaClose := dockerdb.Run(logger, dockerdb.Config{ HostName: "dropSchemaHost", DockerTag: targetVersion, Database: "sub", - Username: "dropSchemaAdmin", + Username: "dropschemaadmin", Password: "dropSchemaSecret", Network: networkName, }) diff --git a/pkg/pgctl/pgdump.go b/pkg/pgctl/pgdump.go index 8465f2e6..7ae81c40 100644 --- a/pkg/pgctl/pgdump.go +++ b/pkg/pgctl/pgdump.go @@ -1,12 +1,15 @@ package pgctl import ( + "bufio" "fmt" "os" "os/exec" "runtime" "strings" "time" + + "github.com/go-logr/logr" ) var ( @@ -21,11 +24,10 @@ type Results struct { } type Result struct { - Mine string - FileName string - Output string - Error *ResultError - FullCommand string + Mine string + FileName string + Output string + Error *ResultError } type ResultError struct { @@ -41,25 +43,49 @@ type Dump struct { Format *string Options []string fileName string + logger logr.Logger } -func NewDump(DsnUri string) *Dump { - return &Dump{Options: PGDumpOpts, DsnUri: DsnUri} +type DumpOptions = func(x *Dump) + +// NewDump creates a new Dump instance with the provided configuration. DSN must be in URI format. 
+func NewDump(DSN string, options ...DumpOptions) *Dump { + d := &Dump{Options: PGDumpOpts, DsnUri: DSN} + for _, option := range options { + option(d) + } + return d } func (x *Dump) Exec(opts ExecOptions) Result { result := Result{Mine: "application/x-tar"} result.FileName = x.GetFileName() options := append(x.dumpOptions(), fmt.Sprintf(`-f%s%v`, x.Path, result.FileName)) - result.FullCommand = strings.Join(options, " ") + + // TODO: sanitize dsn + x.logger.Info("pgdump_database", "full_command", PGDump+" "+strings.Join(options, " ")) + cmd := exec.Command(PGDump, options...) - // cmd.Env = append(os.Environ(), x.EnvPassword) - stderrIn, _ := cmd.StderrPipe() + stderrIn, err := cmd.StderrPipe() + if err != nil { + result.Error = &ResultError{Err: err} + return result + } + go func() { - result.Output = streamExecOutput(stderrIn, opts) + + scanner := bufio.NewScanner(stderrIn) + for scanner.Scan() { + x.logger.Info(scanner.Text()) + } + if err := scanner.Err(); err != nil { + x.logger.Error(err, "Error reading command output") + } + }() + cmd.Start() - err := cmd.Wait() + err = cmd.Wait() if exitError, ok := err.(*exec.ExitError); ok { result.Error = &ResultError{Err: err, ExitCode: exitError.ExitCode(), CmdOutput: result.Output} } @@ -69,10 +95,6 @@ func (x *Dump) ResetOptions() { x.Options = []string{} } -func (x *Dump) EnableVerbose() { - x.Verbose = true -} - func (x *Dump) SetFileName(filename string) { x.fileName = filename } @@ -85,12 +107,34 @@ func (x *Dump) GetFileName() string { return x.fileName } -func (x *Dump) SetupFormat(f string) { - x.Format = &f +func WithFormat(f string) func(x *Dump) { + return func(x *Dump) { + x.Format = &f + } +} + +func WithPath(path string) func(x *Dump) { + return func(x *Dump) { + x.Path = path + } +} + +func WithLogger(logger logr.Logger) func(x *Dump) { + return func(x *Dump) { + x.logger = logger.WithName("pg_dump") + } +} + +func WithVerbose(verbose bool) func(x *Dump) { + return func(x *Dump) { + x.Verbose = 
verbose + } } -func (x *Dump) SetPath(path string) { - x.Path = path +func WithOptions(o []string) func(x *Dump) { + return func(x *Dump) { + x.Options = o + } } func (x *Dump) newFileName() string { @@ -112,9 +156,6 @@ func (x *Dump) dumpOptions() []string { return options } -func (x *Dump) SetOptions(o []string) { - x.Options = o -} func (x *Dump) GetOptions() []string { return x.Options } diff --git a/pkg/pgctl/pgrestore.go b/pkg/pgctl/pgrestore.go index c75cc4bc..917ef911 100644 --- a/pkg/pgctl/pgrestore.go +++ b/pkg/pgctl/pgrestore.go @@ -1,9 +1,17 @@ package pgctl import ( + "context" + "database/sql" + "errors" "fmt" + "net/url" "os/exec" "strings" + "time" + + "github.com/go-logr/logr" + "github.com/lib/pq" ) var ( @@ -12,80 +20,146 @@ var ( ) type Restore struct { - DsnUri string - Verbose bool - Path string - Options []string - Schemas []string + AdminDSN string + UserRole string + DsnUri string + Verbose bool + Options []string + Schemas []string + logger logr.Logger + databaseName string + + adminDB *sql.DB + retries int } +type RestoreOptions = func(x *Restore) + // NewRestore creates a new Restore instance with the provided configuration. 
-func NewRestore(DsnUri string) *Restore { - return &Restore{ - Options: PGDRestoreOpts, - DsnUri: DsnUri, - Schemas: []string{"public"}, +func NewRestore(adminDSN, userDSN, userRole string, options ...RestoreOptions) (*Restore, error) { + u, err := url.Parse(adminDSN) + if err != nil { + return nil, err + } + + u, err = url.Parse(userDSN) + if err != nil { + return nil, err + } + + databaseName := strings.TrimPrefix(u.Path, "/") + if databaseName == "" { + return nil, fmt.Errorf("database name not found in user DSN") + } + + r := &Restore{ + AdminDSN: adminDSN, + Options: PGDRestoreOpts, + DsnUri: userDSN, + Schemas: []string{"public"}, + databaseName: databaseName, + UserRole: userRole, + retries: 3, + } + + r.adminDB, err = sql.Open("postgres", adminDSN) + if err != nil { + return nil, err + } + + for _, option := range options { + option(r) + } + + return r, nil +} + +func WithRestoreLogger(logger logr.Logger) RestoreOptions { + return func(x *Restore) { + x.logger = logger.WithName("pg_restore") + } +} + +// Close closes the database connection. +func (x *Restore) Close() error { + return x.adminDB.Close() +} + +// Exec runs the pg_restore command with the provided +// filename and options. It does this with pgctl cli because +// multiple statements are used and parsing those in Go would +// be error prone. +func (x *Restore) Exec(sqlPath string, opts ExecOptions) error { + count := 0 + for { + err := x.exec(sqlPath, opts) + if err == nil { + return nil + } + if count > x.retries { + x.logger.Error(err, "restore database failed", "count", count) + return err + } + count++ + x.logger.Info("restore failed, backing up schema and re-attempting", "count", count) + if x.moveSchema(context.Background()) != nil { + return fmt.Errorf("failed to move schema: %w", err) + } } } -// Exec runs the pg_restore command with the provided filename and options. 
-func (x *Restore) Exec(filename string, opts ExecOptions) Result { +func (x *Restore) exec(sqlPath string, opts ExecOptions) error { options := []string{ - x.DsnUri, "-vON_ERROR_STOP=ON", - fmt.Sprintf("--file=%s%s", x.Path, filename), + fmt.Sprintf("--file=%s", sqlPath), } options = append(options, x.restoreOptions()...) - result := Result{ - FullCommand: strings.Join(options, " "), - } + logCmd := PSQL + " " + strings.Join(append([]string{SanitizeDSN(x.DsnUri)}, options...), " ") - cmd := exec.Command(PSQL, options...) + x.logger.Info("restoring", "full_command", logCmd) + + args := append([]string{x.DsnUri}, options...) + + cmd := exec.Command(PSQL, args...) // Pipe to capture error output. stderrIn, err := cmd.StderrPipe() if err != nil { - result.Error = &ResultError{Err: err} - return result + return err } - go func() { - result.Output = streamExecOutput(stderrIn, opts) - }() + var lastLine string + go logStdouterr(stderrIn, x.logger, &lastLine) + defer stderrIn.Close() err = cmd.Start() if err != nil { - result.Error = &ResultError{Err: err, CmdOutput: result.Output} - return result + return err } err = cmd.Wait() if err != nil { - if exitError, ok := err.(*exec.ExitError); ok { - result.Error = &ResultError{Err: exitError, ExitCode: exitError.ExitCode(), CmdOutput: result.Output} - return result + exitErr, ok := err.(*exec.ExitError) + if !ok { + return fmt.Errorf("unexpected error executing pgctl: %w", err) } - - result.Error = &ResultError{Err: err, CmdOutput: result.Output} - return result + // Probably ran into an issue with a pgctl line. 
See if we have a lastLine to parse + errLine := parseLastLine(sqlPath, lastLine) + x.logger.Error(exitErr, "pgctl error", "lastLine", errLine) + if strings.Contains(errLine.Error(), "already exists") { + return fmt.Errorf("%s", errLine) + } + return errors.Join(err, fmt.Errorf("%s", errLine)) } - return result + return nil } func (x *Restore) ResetOptions() { x.Options = []string{} } -func (x *Restore) EnableVerbose() { - x.Verbose = true -} - -func (x *Restore) SetPath(path string) { - x.Path = path -} - func (x *Restore) SetSchemas(schemas []string) { x.Schemas = schemas } @@ -107,46 +181,59 @@ func (x *Restore) GetOptions() []string { return x.Options } -// DropSchemas drops all schemas except the system ones. -func (x *Restore) DropSchemas() Result { - dropSchemaSQL := ` - DO $$ DECLARE - r RECORD; - BEGIN - FOR r IN (SELECT nspname FROM pg_namespace WHERE nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') AND nspname !~ '^pg_temp_') LOOP - EXECUTE 'DROP SCHEMA IF EXISTS ' || quote_ident(r.nspname) || ' CASCADE'; - END LOOP; - END $$; - ` - - result := Result{} - cmd := exec.Command(PSQL, x.DsnUri, "-c", dropSchemaSQL) +// moveSchema will backup the existing public schema then +// create a new one owned by the user role. +func (x *Restore) moveSchema(ctx context.Context) error { + if err := x.backupSchema(ctx); err != nil { + return fmt.Errorf("failed to backup schema: %w", err) + } - // Pipe to capture error output. 
- stderrIn, err := cmd.StderrPipe() - if err != nil { - result.Error = &ResultError{Err: err} - return result + if err := x.recreateSchema(ctx); err != nil { + return fmt.Errorf("failed to recreate schemas: %w", err) } - go func() { - result.Output = streamExecOutput(stderrIn, ExecOptions{}) - }() + return nil +} - if err := cmd.Start(); err != nil { - result.Error = &ResultError{Err: err, CmdOutput: result.Output} - return result +func (x *Restore) backupSchema(ctx context.Context) error { + sql := fmt.Sprintf(`DO $$ +BEGIN + IF EXISTS (SELECT 1 FROM pg_namespace WHERE nspname = 'public') THEN + EXECUTE 'ALTER SCHEMA public RENAME TO public_migrate_failed_%s'; + END IF; +END $$; +`, time.Now().Format("20060102150405")) + + return x.sqlExec(ctx, "backup_schema", sql) +} + +// DropSchemas drops all schemas except the system ones. +func (x *Restore) recreateSchema(ctx context.Context) error { + if x.UserRole == "" { + return fmt.Errorf("user role not found") } - if err := cmd.Wait(); err != nil { - if exitError, ok := err.(*exec.ExitError); ok { - result.Error = &ResultError{Err: exitError, ExitCode: exitError.ExitCode(), CmdOutput: result.Output} - return result - } + if err := x.sqlExec(context.TODO(), "create_schema", "CREATE SCHEMA public AUTHORIZATION "+pq.QuoteIdentifier(x.UserRole)); err != nil { + return fmt.Errorf("failed to create public schema: %w", err) + } + // Re-apply standard public permissions + return x.sqlExec(context.TODO(), "grant_schema", `GRANT USAGE ON SCHEMA public to PUBLIC;`) +} - result.Error = &ResultError{Err: err, CmdOutput: result.Output} - return result +func (x *Restore) sqlExec(ctx context.Context, action, query string, args ...any) error { + + x.logger.V(1).Info(action, "full_command", query) + + db, err := sql.Open("postgres", x.AdminDSN) + if err != nil { + return fmt.Errorf("failed to open database connection: %w", err) + } + defer db.Close() + + _, err = db.ExecContext(ctx, query, args...) 
+ if err != nil { + return fmt.Errorf("failed to execute query: %w query: %s", err, query) } - return result + return err } diff --git a/pkg/pgctl/pgrestore_test.go b/pkg/pgctl/pgrestore_test.go index db67cf57..ef30343e 100644 --- a/pkg/pgctl/pgrestore_test.go +++ b/pkg/pgctl/pgrestore_test.go @@ -1,6 +1,7 @@ package pgctl import ( + "context" "database/sql" "fmt" "testing" @@ -8,7 +9,8 @@ import ( "github.com/lib/pq" ) -func TestDropSchemas(t *testing.T) { +func TestBackupSchemas(t *testing.T) { + // Create a connection to the isolated test database. testDB, err := getDB(dropSchemaDBAdminDsn, nil) if err != nil { @@ -16,23 +18,26 @@ func TestDropSchemas(t *testing.T) { } defer closeDB(logger, testDB) - restore := NewRestore(dropSchemaDBAdminDsn) - - t.Run("Drop existing schemas", func(t *testing.T) { - setupTestSchemas(t, testDB) + // FIXME: This unit tests needs to verify it works with + // a user dsn like is used in the controller code. + restore, err := NewRestore(dropSchemaDBAdminDsn, dropSchemaDBAdminDsn, "dropschemaadmin") + if err != nil { + t.Fatal(err) + } + t.Run("Backup public schema", func(t *testing.T) { // Verify the schemas were created. schemasBefore, err := listSchemas(testDB) if err != nil { t.Fatalf("failed to list schemas before drop: %v", err) } + if len(schemasBefore) == 0 { t.Fatal("no schemas were created for testing") } - dropResult := restore.DropSchemas() - if dropResult.Error != nil { - t.Fatalf("drop schemas failed: %v", dropResult.Error.Err) + if err = restore.backupSchema(context.TODO()); err != nil { + t.Fatalf("drop schemas failed: %s", err) } // Verify that all non-system schemas were dropped. @@ -43,9 +48,24 @@ func TestDropSchemas(t *testing.T) { if len(schemasAfter) != 0 { t.Fatalf("expected no schemas after drop, but found: %v", schemasAfter) } + + if err := restore.recreateSchema(context.TODO()); err != nil { + t.Fatal(err) + } + + // Verify that the schemas were recreated. 
+ schemasAfterRecreate, err := listSchemas(testDB) + if err != nil { + t.Fatalf("failed to list schemas after recreate: %v", err) + } + if len(schemasAfterRecreate) != len(schemasBefore) { + t.Fatalf("expected %d schemas after recreate, but found: %v", len(schemasBefore), schemasAfterRecreate) + } + }) t.Run("Drop schemas when there are no schemas", func(t *testing.T) { + t.Skip("this is not supported") // Ensure no schemas exist before the test. schemasBefore, err := listSchemas(testDB) if err != nil { @@ -56,9 +76,9 @@ func TestDropSchemas(t *testing.T) { } // Call DropSchemas to verify it handles the empty state correctly. - dropResult := restore.DropSchemas() - if dropResult.Error != nil { - t.Fatalf("drop schemas failed when no schemas existed: %v", dropResult.Error.Err) + + if err := restore.backupSchema(context.TODO()); err != nil { + t.Fatalf("drop schemas failed when no schemas existed: %s", err) } }) } @@ -85,7 +105,9 @@ func listSchemas(db *sql.DB) ([]string, error) { SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - AND schema_name NOT LIKE 'pg_temp_%'`) + AND schema_name NOT LIKE 'pg_temp_%' + AND schema_name NOT LIKE 'public_migrate_failed%' +`) if err != nil { return nil, err } diff --git a/pkg/pgctl/subscription.go b/pkg/pgctl/subscription.go new file mode 100644 index 00000000..c04640ed --- /dev/null +++ b/pkg/pgctl/subscription.go @@ -0,0 +1,48 @@ +package pgctl + +import ( + "database/sql" + "fmt" +) + +// SubscriptionStatus represents the status of a PostgreSQL subscription +type SubscriptionStatus struct { + TableName string + SRSubState string + State string + LSN sql.NullString + CurrentLSN sql.NullString + LagSize sql.NullString +} + +// getSubscriptionStatus returns the status of a PostgreSQL subscription +func getSubscriptionStatus(db *sql.DB, subscriptionName string) (*SubscriptionStatus, error) { + // Query to get subscription status information + query := 
`SELECT + sr.srrelid::regclass as table_name, + sr.srsubstate, + CASE sr.srsubstate + WHEN 'i' THEN 'Initializing' + WHEN 'd' THEN 'Data Copying' + WHEN 's' THEN 'Synchronized' + WHEN 'r' THEN 'Ready' + WHEN 'f' THEN 'Failed' + END as state, + sr.srsublsn as lsn, + pg_last_wal_receive_lsn()::text as current_lsn, + pg_size_pretty(pg_wal_lsn_diff(pg_last_wal_receive_lsn(), sr.srsublsn)) as lag_size +FROM pg_subscription s +JOIN pg_subscription_rel sr ON s.oid = sr.srsubid +WHERE s.subname = $1; +` + + var status SubscriptionStatus + err := db.QueryRow(query, subscriptionName).Scan(&status.TableName, &status.SRSubState, &status.State, &status.LSN, &status.CurrentLSN, &status.LagSize) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("subscription '%s' does not exist", subscriptionName) + } else if err != nil { + return nil, err + } + + return &status, nil +} diff --git a/pkg/pgctl/utils.go b/pkg/pgctl/utils.go index 6fdc167c..f91b01aa 100644 --- a/pkg/pgctl/utils.go +++ b/pkg/pgctl/utils.go @@ -8,6 +8,7 @@ import ( "io" "net/url" "os/exec" + "strings" "time" "github.com/go-logr/logr" @@ -150,3 +151,23 @@ func SanitizeDSN(dsn string) string { } return u.Redacted() } + +// logStdouterr reads the command output line by line, logging each line and recording the most recent line in lastLine.
+func logStdouterr(out io.ReadCloser, logger logr.Logger, lastLine *string) { + scanner := bufio.NewScanner(out) + for scanner.Scan() { + line := scanner.Text() + logger.Info(line) + *lastLine = line + } + if err := scanner.Err(); err != nil { + logger.Error(err, "Error reading command output") + } +} + +// parseLastLine attempts to normalize lines emitted by pgctl +// before: sql:/tmp/pub_1731361094.sql:36: ERROR: function "calculate_user_total" already exists with same argument types +// after: ERROR: function "calculate_user_total" already exists with same argument types +func parseLastLine(sqlFilePath, line string) error { + return fmt.Errorf(strings.TrimPrefix(line, fmt.Sprintf("psql:%s:", sqlFilePath))) +} diff --git a/pkg/pgctl/utils_test.go b/pkg/pgctl/utils_test.go new file mode 100644 index 00000000..82a3ea86 --- /dev/null +++ b/pkg/pgctl/utils_test.go @@ -0,0 +1,26 @@ +package pgctl + +import ( + "errors" + "testing" +) + +func TestParseLastLine(t *testing.T) { + filePath := "/tmp/pub_1731362625.sql" + for _, tt := range []struct { + line string + want error + }{ + { + line: `psql:/tmp/pub_1731362625.sql:36: ERROR: function "calculate_user_total" already exists with same argument types`, + want: errors.New(`36: ERROR: function "calculate_user_total" already exists with same argument types`), + }, + } { + + got := parseLastLine(filePath, tt.line) + if got.Error() != tt.want.Error() { + t.Errorf("got %s, want %s", got.Error(), tt.want.Error()) + } + + } +} diff --git a/pkg/roleclaim/roleclaim_test.go b/pkg/roleclaim/roleclaim_test.go index 3ef3cc84..00bf33bd 100644 --- a/pkg/roleclaim/roleclaim_test.go +++ b/pkg/roleclaim/roleclaim_test.go @@ -33,10 +33,11 @@ type reconciler struct { var viperObj = viper.New() func TestDBRoleClaimController_CreateSchemasAndRoles(t *testing.T) { - logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + logger := zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)) + logf.SetLogger(logger) 
RegisterFailHandler(Fail) - _, dsn, close := dockerdb.Run(dockerdb.Config{ + _, dsn, close := dockerdb.Run(logger, dockerdb.Config{ Username: "mainUser", Password: "masterpassword", Database: "postgres", @@ -144,10 +145,11 @@ func TestDBRoleClaimController_CreateSchemasAndRoles(t *testing.T) { } func TestDBRoleClaimController_ExistingSchemaRoleAndUser(t *testing.T) { - logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + logger := zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)) + logf.SetLogger(logger) RegisterFailHandler(Fail) - _, dsn, close := dockerdb.Run(dockerdb.Config{ + _, dsn, close := dockerdb.Run(logger, dockerdb.Config{ Username: "mainUser", Password: "masterpassword", Database: "postgres", @@ -248,10 +250,11 @@ func TestDBRoleClaimController_ExistingSchemaRoleAndUser(t *testing.T) { } func TestDBRoleClaimController_RevokeRolesAndAssignNew(t *testing.T) { - logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) + logger := zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)) + logf.SetLogger(logger) RegisterFailHandler(Fail) - _, dsn, close := dockerdb.Run(dockerdb.Config{ + _, dsn, close := dockerdb.Run(logger, dockerdb.Config{ Username: "mainUser", Password: "masterpassword", Database: "postgres",