Skip to content

Change the Generate function to take a SchemaSource instead of a sqldb.Queryable #189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ if err != nil {
}
defer tempDbFactory.Close()
// Generate the migration plan
plan, err := diff.Generate(ctx, connPool, diff.DDLSchemaSource(ddl),
plan, err := diff.Generate(ctx, diff.DBSchemaSource(connPool), diff.DDLSchemaSource(ddl),
diff.WithTempDbFactory(tempDbFactory),
diff.WithDataPackNewTables(),
)
Expand Down
4 changes: 3 additions & 1 deletion cmd/pg-schema-diff/plan_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ func generatePlan(ctx context.Context, logger log.Logger, connConfig *pgx.ConnCo
defer schemaSourceCloser.Close()
}

plan, err := diff.Generate(ctx, connPool, schemaSource,
connSource := diff.DBSchemaSource(connPool)

plan, err := diff.Generate(ctx, connSource, schemaSource,
append(
planConfig.opts,
diff.WithTempDbFactory(tempDbFactory),
Expand Down
5 changes: 4 additions & 1 deletion internal/migration_acceptance_tests/acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func (suite *acceptanceTestSuite) runTest(tc acceptanceTestCase) {
}
if tc.planFactory == nil {
tc.planFactory = func(ctx context.Context, connPool sqldb.Queryable, tempDbFactory tempdb.Factory, newSchemaDDL []string, opts ...diff.PlanOpt) (diff.Plan, error) {
return diff.Generate(ctx, connPool, diff.DDLSchemaSource(newSchemaDDL),

connSource := diff.DBSchemaSource(connPool)

return diff.Generate(ctx, connSource, diff.DDLSchemaSource(newSchemaDDL),
append(tc.planOpts,
diff.WithTempDbFactory(tempDbFactory),
)...)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func databaseSchemaSourcePlan(ctx context.Context, connPool sqldb.Queryable, tem
opts = append(opts, diff.WithGetSchemaOpts(o))
}

return diff.Generate(ctx, connPool, diff.DBSchemaSource(newSchemaDb.ConnPool), opts...)
return diff.Generate(ctx, diff.DBSchemaSource(connPool), diff.DBSchemaSource(newSchemaDb.ConnPool), opts...)
}

func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory {
Expand All @@ -53,7 +53,9 @@ func dirSchemaSourcePlanFactory(schemaDirs []string) planFactory {
return diff.Plan{}, fmt.Errorf("creating schema source: %w", err)
}

return diff.Generate(ctx, connPool, schemaSource, opts...)
connSource := diff.DBSchemaSource(connPool)

return diff.Generate(ctx, connSource, schemaSource, opts...)
}
}

Expand All @@ -65,7 +67,7 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{
oldSchemaDDL: []string{
`
CREATE TABLE fizz();

CREATE TABLE foobar(
id INT,
bar SERIAL NOT NULL,
Expand Down Expand Up @@ -134,7 +136,7 @@ var databaseSchemaSourceTestCases = []acceptanceTestCase{
CREATE INDEX bar_normal_idx ON bar(bar);
CREATE INDEX bar_another_normal_id ON bar(bar, fizz);
CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz);

`,
},
expectedHazardTypes: []diff.MigrationHazardType{
Expand Down
16 changes: 11 additions & 5 deletions pkg/diff/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,22 @@ func WithGetSchemaOpts(getSchemaOpts ...externalschema.GetSchemaOpt) PlanOpt {
// newDDL: DDL encoding the new schema
// opts: Additional options to configure the plan generation
func GeneratePlan(ctx context.Context, queryable sqldb.Queryable, tempdbFactory tempdb.Factory, newDDL []string, opts ...PlanOpt) (Plan, error) {
return Generate(ctx, queryable, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)

schemaSource := DBSchemaSource(queryable)

return Generate(ctx, schemaSource, DDLSchemaSource(newDDL), append(opts, WithTempDbFactory(tempdbFactory), WithIncludeSchemas("public"))...)
}

// Generate generates a migration plan to migrate the database to the target schema
//
// Parameters:
// fromDB: The target database to generate the diff for. It is recommended to pass in *sql.DB of the db you
// wish to migrate. If using a connection pool, it is RECOMMENDED to set a maximum number of connections.
// fromSchema: The target schema to generate the diff for.
// targetSchema: The (source of the) schema you want to migrate the database to. Use DDLSchemaSource if the new
// schema is encoded in DDL.
// opts: Additional options to configure the plan generation
func Generate(
ctx context.Context,
fromDB sqldb.Queryable,
fromSchema SchemaSource,
targetSchema SchemaSource,
opts ...PlanOpt,
) (Plan, error) {
Expand All @@ -132,7 +134,11 @@ func Generate(
opt(planOptions)
}

currentSchema, err := schema.GetSchema(ctx, fromDB, planOptions.getSchemaOpts...)
currentSchema, err := fromSchema.GetSchema(ctx, schemaSourcePlanDeps{
tempDBFactory: planOptions.tempDbFactory,
logger: planOptions.logger,
getSchemaOpts: planOptions.getSchemaOpts,
})
if err != nil {
return Plan{}, fmt.Errorf("getting current schema: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/diff/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (suite *planGeneratorTestSuite) TestGenerate() {
tempDbFactory := suite.mustBuildTempDbFactory(context.Background())
defer tempDbFactory.Close()

plan, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory))
plan, err := Generate(context.Background(), DBSchemaSource(connPool), DDLSchemaSource([]string{newSchemaDDL}), WithTempDbFactory(tempDbFactory))
suite.NoError(err)

suite.mustApplyMigrationPlan(connPool, plan)
Expand Down Expand Up @@ -140,7 +140,7 @@ func (suite *planGeneratorTestSuite) TestGeneratePlan_SchemaSourceErr() {
connPool := suite.mustGetTestDBPool()
defer connPool.Close()

_, err := Generate(context.Background(), connPool, fakeSchemaSource,
_, err := Generate(context.Background(), DBSchemaSource(connPool), fakeSchemaSource,
WithTempDbFactory(tempDbFactory),
WithGetSchemaOpts(getSchemaOpts...),
WithLogger(logger),
Expand All @@ -163,7 +163,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgno
connPool := suite.mustGetTestDBPool()
defer connPool.Close()

_, err := Generate(context.Background(), connPool, DDLSchemaSource([]string{``}),
_, err := Generate(context.Background(), DBSchemaSource(connPool), DDLSchemaSource([]string{``}),
WithTempDbFactory(tempDbFactory),
WithDataPackNewTables(),
WithRespectColumnOrder(),
Expand All @@ -174,7 +174,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotPackNewTablesWithoutIgno
func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWithoutTempDbFactory() {
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
_, err := Generate(context.Background(), DBSchemaSource(pool), DDLSchemaSource([]string{``}),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
Expand All @@ -184,7 +184,7 @@ func (suite *planGeneratorTestSuite) TestGenerate_CannotBuildMigrationFromDDLWit
func (suite *planGeneratorTestSuite) TestGenerate_CannotValidateWithoutTempDbFactory() {
pool := suite.mustGetTestDBPool()
defer pool.Close()
_, err := Generate(context.Background(), pool, DDLSchemaSource([]string{``}),
_, err := Generate(context.Background(), DBSchemaSource(pool), DDLSchemaSource([]string{``}),
WithIncludeSchemas("public"),
WithDoNotValidatePlan(),
)
Expand Down