From a6f6fc54e0091254cb2ff229d9b005824be6c2c4 Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 28 Jan 2025 14:03:48 +0100 Subject: [PATCH] Change Generate() function to a SchemaSource instead of a db for the from parameter --- README.md | 2 +- cmd/pg-schema-diff/plan_cmd.go | 4 +++- .../acceptance_test.go | 5 ++++- .../database_schema_source_cases_test.go | 10 ++++++---- pkg/diff/plan_generator.go | 16 +++++++++++----- pkg/diff/plan_generator_test.go | 10 +++++----- 6 files changed, 30 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index f847ccc..92d6f38 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ if err != nil { } defer tempDbFactory.Close() // Generate the migration plan -plan, err := diff.Generate(ctx, connPool, diff.DDLSchemaSource(ddl), +plan, err := diff.Generate(ctx, diff.DBSchemaSource(connPool), diff.DDLSchemaSource(ddl), diff.WithTempDbFactory(tempDbFactory), diff.WithDataPackNewTables(), ) diff --git a/cmd/pg-schema-diff/plan_cmd.go b/cmd/pg-schema-diff/plan_cmd.go index c982e63..36eabb2 100644 --- a/cmd/pg-schema-diff/plan_cmd.go +++ b/cmd/pg-schema-diff/plan_cmd.go @@ -410,7 +410,9 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo defer schemaSourceCloser.Close() } - plan, err := diff.Generate(ctx, connPool, schemaSource, + connSource := diff.DBSchemaSource(connPool) + + plan, err := diff.Generate(ctx, connSource, schemaSource, append( planConfig.opts, diff.WithTempDbFactory(tempDbFactory), diff --git a/internal/migration_acceptance_tests/acceptance_test.go b/internal/migration_acceptance_tests/acceptance_test.go index 3d51cda..4eacb4f 100644 --- a/internal/migration_acceptance_tests/acceptance_test.go +++ b/internal/migration_acceptance_tests/acceptance_test.go @@ -96,7 +96,10 @@ func (suite *acceptanceTestSuite) runTest(tc acceptanceTestCase) { } if tc.planFactory == nil { tc.planFactory = func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (diff.Plan, error) { - return diff.Generate(ctx, connPool, diff.DDLSchemaSource(newSchemaDDL), + + connSource := diff.DBSchemaSource(connPool) + + return diff.Generate(ctx, connSource, diff.DDLSchemaSource(newSchemaDDL), append(tc.planOpts, diff.WithTempDbFactory(tempDbFactory), )...) diff --git a/internal/migration_acceptance_tests/database_schema_source_cases_test.go b/internal/migration_acceptance_tests/database_schema_source_cases_test.go index 6ba536a..dcc8af6 100644 --- a/internal/migration_acceptance_tests/database_schema_source_cases_test.go +++ b/internal/migration_acceptance_tests/database_schema_source_cases_test.go @@ -35,7 +35,7 @@ func databaseSchemaSourcePlan(ctx context.Context, connPool sqldb.Queryable, tem opts = append(opts, diff.WithGetSchemaOpts(o)) } - return diff.Generate(ctx, connPool, diff.DBSchemaSource(newSchemaDb.ConnPool), opts...) + return diff.Generate(ctx, diff.DBSchemaSource(connPool), diff.DBSchemaSource(newSchemaDb.ConnPool), opts...) } func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory { @@ -53,7 +53,9 @@ func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory { return diff.Plan{}, fmt.Errorf("creating schema source: %w", err) } - return diff.Generate(ctx, connPool, schemaSource, opts...) + connSource := diff.DBSchemaSource(connPool) + + return diff.Generate(ctx, connSource, schemaSource, opts...) } } @@ -65,7 +67,7 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{ oldSchemaDDL: []string{ ` CREATE TABLE fizz(); - + CREATE TABLE foobar( id INT, bar SERIAL NOT NULL, @@ -134,7 +136,7 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{ CREATE INDEX bar_normal_idx ON bar(bar); CREATE INDEX bar_another_normal_id ON bar(bar, fizz); CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); - + `, }, expectedHazardTypes: []diff.MigrationHazardType{ diff --git a/pkg/diff/plan_generator.go b/pkg/diff/plan_generator.go index a897781..b2c5a49 100644 --- a/pkg/diff/plan_generator.go +++ b/pkg/diff/plan_generator.go @@ -106,20 +106,22 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt { // newDDL: DDL encoding the new schema // opts: Additional options to configure the plan generation func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) { - return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...) + + schemaSource := DBSchemaSource(queryable) + + return Generate(ctx, schemaSource, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...) } // Generate generates a migration plan to migrate the database to the target schema // // Parameters: -// fromDB: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you -// wish to migrate. If using a connection pool, it is RECOMMENDED to set a maximum number of connections. +// fromSchema: The target schema to generate the diff for. // targetSchema: The (source of the) schema you want to migrate the database to. Use DDLSchemaSource if the new // schema is encoded in DDL. // opts: Additional options to configure the plan generation func Generate( ctx context.Context, - fromDB sqldb.Queryable, + fromSchema SchemaSource, targetSchema SchemaSource, opts ...PlanOpt, ) (Plan, error) { @@ -132,7 +134,11 @@ func Generate( opt(planOptions) } - currentSchema, err := schema.GetSchema(ctx, fromDB, planOptions.getSchemaOpts...) + currentSchema, err := fromSchema.GetSchema(ctx, schemaSourcePlanDeps{ + tempDBFactory: planOptions.tempDbFactory, + logger: planOptions.logger, + getSchemaOpts: planOptions.getSchemaOpts, + }) if err != nil { return Plan{}, fmt.Errorf("getting current schema: %w", err) } diff --git a/pkg/diff/plan_generator_test.go b/pkg/diff/plan_generator_test.go index a615bcf..dc93ca4 100644 --- a/pkg/diff/plan_generator_test.go +++ b/pkg/diff/plan_generator_test.go @@ -104,7 +104,7 @@ func (suite *planGeneratorTestSuite) TestGenerate() { tempDbFactory := suite.mustBuildTempDbFactory(context.Background()) defer tempDbFactory.Close() - plan, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory)) + plan, err := Generate(context.Background(), DBSchemaSource(connPool), DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory)) suite.NoError(err) suite.mustApplyMigrationPlan(connPool, plan) @@ -140,7 +140,7 @@ func (suite *planGeneratorTestSuite) TestGeneratePlan_SchemaSourceErr() { connPool := suite.mustGetTestDBPool() defer connPool.Close() - _, err := Generate(context.Background(), connPool, fakeSchemaSource, + _, err := Generate(context.Background(), DBSchemaSource(connPool), fakeSchemaSource, WithTempDbFactory(tempDbFactory), WithGetSchemaOpts(getSchemaOpts...), WithLogger(logger), @@ -163,7 +163,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgno connPool := suite.mustGetTestDBPool() defer connPool.Close() - _, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{``}), + _, err := Generate(context.Background(), DBSchemaSource(connPool), DDLSchemaSource([]string{``}), WithTempDbFactory(tempDbFactory), WithDataPackNewTables(), WithRespectColumnOrder(), @@ -174,7 +174,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgno func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWithoutTempDbFactory() { pool := suite.mustGetTestDBPool() defer pool.Close() - _, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}), + _, err := Generate(context.Background(), DBSchemaSource(pool), DDLSchemaSource([]string{``}), WithIncludeSchemas("public"), WithDoNotValidatePlan(), ) @@ -184,7 +184,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWit func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFactory() { pool := suite.mustGetTestDBPool() defer pool.Close() - _, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}), + _, err := Generate(context.Background(), DBSchemaSource(pool), DDLSchemaSource([]string{``}), WithIncludeSchemas("public"), WithDoNotValidatePlan(), )