Skip to content

Commit

Permalink
Change Generate() function to a SchemaSource instead of a db for the …
Browse files Browse the repository at this point in the history
…from parameter (#189)
  • Loading branch information
the-glu authored Feb 4, 2025
1 parent 34273e5 commit 7e225f6
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ if err != nil {
}
defer tempDbFactory.Close()
// Generate the migration plan
plan, err := diff.Generate(ctx, connPool, diff.DDLSchemaSource(ddl),
plan, err := diff.Generate(ctx, diff.DBSchemaSource(connPool), diff.DDLSchemaSource(ddl),
diff.WithTempDbFactory(tempDbFactory),
diff.WithDataPackNewTables(),
)
Expand Down
4 changes: 3 additions & 1 deletion cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo
defer schemaSourceCloser.Close()
}

plan, err := diff.Generate(ctx, connPool, schemaSource,
connSource := diff.DBSchemaSource(connPool)

plan, err := diff.Generate(ctx, connSource, schemaSource,
append(
planConfig.opts,
diff.WithTempDbFactory(tempDbFactory),
Expand Down
5 changes: 4 additions & 1 deletion internal/migration_acceptance_tests/acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (suite *acceptanceTestSuite) runTest(tc acceptanceTestCase) {
}
if tc.planFactory == nil {
tc.planFactory = func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (diff.Plan, error) {
return diff.Generate(ctx, connPool, diff.DDLSchemaSource(newSchemaDDL),

connSource := diff.DBSchemaSource(connPool)

return diff.Generate(ctx, connSource, diff.DDLSchemaSource(newSchemaDDL),
append(tc.planOpts,
diff.WithTempDbFactory(tempDbFactory),
)...)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func databaseSchemaSourcePlan(ctx context.Context, connPool sqldb.Queryable, tem
opts = append(opts, diff.WithGetSchemaOpts(o))
}

return diff.Generate(ctx, connPool, diff.DBSchemaSource(newSchemaDb.ConnPool), opts...)
return diff.Generate(ctx, diff.DBSchemaSource(connPool), diff.DBSchemaSource(newSchemaDb.ConnPool), opts...)
}

func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory {
Expand All @@ -53,7 +53,9 @@ func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory {
return diff.Plan{}, fmt.Errorf("creating schema source: %w", err)
}

return diff.Generate(ctx, connPool, schemaSource, opts...)
connSource := diff.DBSchemaSource(connPool)

return diff.Generate(ctx, connSource, schemaSource, opts...)
}
}

Expand All @@ -65,7 +67,7 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{
oldSchemaDDL: []string{
`
CREATE TABLE fizz();
CREATE TABLE foobar(
id INT,
bar SERIAL NOT NULL,
Expand Down Expand Up @@ -134,7 +136,7 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{
CREATE INDEX bar_normal_idx ON bar(bar);
CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
Expand Down
16 changes: 11 additions & 5 deletions pkg/diff/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,22 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// newDDL: DDL encoding the new schema
// opts: Additional options to configure the plan generation
func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)

schemaSource := DBSchemaSource(queryable)

return Generate(ctx, schemaSource, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
}

// Generate generates a migration plan to migrate the database to the target schema
//
// Parameters:
// fromDB: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you
// wish to migrate. If using a connection pool, it is RECOMMENDED to set a maximum number of connections.
// fromSchema: The target schema to generate the diff for.
// targetSchema: The (source of the) schema you want to migrate the database to. Use DDLSchemaSource if the new
// schema is encoded in DDL.
// opts: Additional options to configure the plan generation
func Generate(
ctx context.Context,
fromDB sqldb.Queryable,
fromSchema SchemaSource,
targetSchema SchemaSource,
opts ...PlanOpt,
) (Plan, error) {
Expand All @@ -132,7 +134,11 @@ func Generate(
opt(planOptions)
}

currentSchema, err := schema.GetSchema(ctx, fromDB, planOptions.getSchemaOpts...)
currentSchema, err := fromSchema.GetSchema(ctx, schemaSourcePlanDeps{
tempDBFactory: planOptions.tempDbFactory,
logger: planOptions.logger,
getSchemaOpts: planOptions.getSchemaOpts,
})
if err != nil {
return Plan{}, fmt.Errorf("getting current schema: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/diff/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (suite *planGeneratorTestSuite) TestGenerate() {
tempDbFactory := suite.mustBuildTempDbFactory(context.Background())
defer tempDbFactory.Close()

plan, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory))
plan, err := Generate(context.Background(), DBSchemaSource(connPool), DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory))
suite.NoError(err)

suite.mustApplyMigrationPlan(connPool, plan)
Expand Down Expand Up @@ -140,7 +140,7 @@ func (suite *planGeneratorTestSuite) TestGeneratePlan_SchemaSourceErr() {
connPool := suite.mustGetTestDBPool()
defer connPool.Close()

_, err := Generate(context.Background(), connPool, fakeSchemaSource,
_, err := Generate(context.Background(), DBSchemaSource(connPool), fakeSchemaSource,
WithTempDbFactory(tempDbFactory),
WithGetSchemaOpts(getSchemaOpts...),
WithLogger(logger),
Expand All @@ -163,7 +163,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgno
connPool := suite.mustGetTestDBPool()
defer connPool.Close()

_, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{``}),
_, err := Generate(context.Background(), DBSchemaSource(connPool), DDLSchemaSource([]string{``}),
WithTempDbFactory(tempDbFactory),
WithDataPackNewTables(),
WithRespectColumnOrder(),
Expand All @@ -174,7 +174,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgno
func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWithoutTempDbFactory() {
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
_, err := Generate(context.Background(), DBSchemaSource(pool), DDLSchemaSource([]string{``}),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
Expand All @@ -184,7 +184,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWit
func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFactory() {
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
_, err := Generate(context.Background(), DBSchemaSource(pool), DDLSchemaSource([]string{``}),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
Expand Down

0 comments on commit 7e225f6

Please sign in to comment.