diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index 450f12513..af103fe86 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -9,10 +9,12 @@ import ( ) func (d *Dialect) Migrator(db *bun.DB) sqlschema.Migrator { - return &Migrator{db: db} + return &Migrator{db: db, BaseMigrator: sqlschema.NewBaseMigrator(db)} } type Migrator struct { + *sqlschema.BaseMigrator + db *bun.DB } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 12f310b36..91bb59265 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -3,7 +3,9 @@ package dbtest_test import ( "context" "errors" + "sort" "testing" + "time" "github.com/stretchr/testify/require" "github.com/uptrace/bun" @@ -167,6 +169,7 @@ func TestAutoMigrator_Run(t *testing.T) { fn func(t *testing.T, db *bun.DB) }{ {testRenameTable}, + {testCreateDropTable}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -217,68 +220,181 @@ func testRenameTable(t *testing.T, db *bun.DB) { require.Equal(t, "changed", tables[0].Name) } -func TestDetector_Diff(t *testing.T) { - tests := []struct { - states func(testing.TB, context.Context, schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) - operations []migrate.Operation - }{ - { - states: testDetectRenamedTable, - operations: []migrate.Operation{ - &migrate.RenameTable{ - From: "books", - To: "books_renamed", - }, - }, - }, +func testCreateDropTable(t *testing.T, db *bun.DB) { + type DropMe struct { + bun.BaseModel `bun:"table:dropme"` + Foo int `bun:"foo,identity"` + } + + type CreateMe struct { + bun.BaseModel `bun:"table:createme"` + Bar string `bun:",pk,default:gen_random_uuid()"` + Baz time.Time } + // Arrange + ctx := context.Background() + dbInspector, err := sqlschema.NewInspector(db) + if err != nil { + t.Skip(err) + } + mustResetModel(t, ctx, db, (*DropMe)(nil)) + mustDropTableOnCleanup(t, ctx, db, (*CreateMe)(nil)) + + m, err := migrate.NewAutoMigrator(db, + migrate.WithTableNameAuto(migrationsTable), + migrate.WithLocksTableNameAuto(migrationLocksTable), + migrate.WithModel((*CreateMe)(nil))) + require.NoError(t, err) + + // Act + err = m.Run(ctx) + require.NoError(t, err) + + // Assert + state, err := dbInspector.Inspect(ctx) + require.NoError(t, err) + + tables := state.Tables + require.Len(t, tables, 1) + require.Equal(t, "createme", tables[0].Name) +} + +type Journal struct { + ISBN string `bun:"isbn,pk"` + Title string `bun:"title,notnull"` + Pages int `bun:"page_count,notnull,default:0"` +} + +type Reader struct { + Username string `bun:",pk,default:gen_random_uuid()"` +} + +type ExternalUsers struct { + bun.BaseModel `bun:"external.users"` + Name string `bun:",pk"` +} + +func TestDetector_Diff(t *testing.T) { testEachDialect(t, func(t *testing.T, dialectName string, dialect schema.Dialect) { - for _, tt := range tests { + for _, tt := range []struct { + name string + states func(testing.TB, context.Context, schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) + want []migrate.Operation + }{ + { + name: "1 table renamed, 1 added, 2 dropped", + states: func(tb testing.TB, ctx context.Context, d schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) { + // Database state ------------- + type Subscription struct { + bun.BaseModel `bun:"table:billing.subscriptions"` + } + type Review struct{} + + type Author struct { + Name string `bun:"name"` + } + + // Model state ------------- + type JournalRenamed struct { + bun.BaseModel `bun:"table:journals_renamed"` + + ISBN string `bun:"isbn,pk"` + Title string `bun:"title,notnull"` + Pages int `bun:"page_count,notnull,default:0"` + } + + return getState(tb, ctx, d, + (*Author)(nil), + (*Journal)(nil), + (*Review)(nil), + (*Subscription)(nil), + ), getState(tb, ctx, d, + (*Author)(nil), + (*JournalRenamed)(nil), + (*Reader)(nil), + ) + }, + want: []migrate.Operation{ + &migrate.RenameTable{ + Schema: dialect.DefaultSchema(), + From: "journals", + To: "journals_renamed", + }, + &migrate.CreateTable{ + Model: &Reader{}, // (*Reader)(nil) would be more idiomatic, but schema.Tables + }, + &migrate.DropTable{ + Schema: "billing", + Name: "billing.subscriptions", // TODO: fix once schema is used correctly + }, + &migrate.DropTable{ + Schema: dialect.DefaultSchema(), + Name: "reviews", + }, + }, + }, + { + name: "renaming does not work across schemas", + states: func(tb testing.TB, ctx context.Context, d schema.Dialect) (stateDb sqlschema.State, stateModel sqlschema.State) { + // Users have the same columns as the "added" ExternalUsers. + // However, we should not recognize it as a RENAME, because only models in the same schema can be renamed. + // Instead, this is a DROP + CREATE case. + type Users struct { + bun.BaseModel `bun:"external_users"` + Name string `bun:",pk"` + } + + return getState(tb, ctx, d, + (*Users)(nil), + ), getState(t, ctx, d, + (*ExternalUsers)(nil), + ) + }, + want: []migrate.Operation{ + &migrate.DropTable{ + Schema: dialect.DefaultSchema(), + Name: "external_users", + }, + &migrate.CreateTable{ + Model: &ExternalUsers{}, + }, + }, + }, + } { t.Run(funcName(tt.states), func(t *testing.T) { ctx := context.Background() var d migrate.Detector stateDb, stateModel := tt.states(t, ctx, dialect) - diff := d.Diff(stateDb, stateModel) - - require.Equal(t, tt.operations, diff.Operations()) + got := d.Diff(stateDb, stateModel).Operations() + checkEqualChangeset(t, got, tt.want) }) } }) } -func testDetectRenamedTable(tb testing.TB, ctx context.Context, dialect schema.Dialect) (s1, s2 sqlschema.State) { - type Book struct { - bun.BaseModel +func checkEqualChangeset(tb testing.TB, got, want []migrate.Operation) { + tb.Helper() - ISBN string `bun:"isbn,pk"` - Title string `bun:"title,notnull"` - Pages int `bun:"page_count,notnull,default:0"` - } - - type Author struct { - bun.BaseModel - Name string `bun:"name"` - } + // Sort alphabetically to ensure we don't fail because of the wrong order + sort.Slice(got, func(i, j int) bool { + return got[i].String() < got[j].String() + }) + sort.Slice(want, func(i, j int) bool { + return want[i].String() < want[j].String() + }) - type BookRenamed struct { - bun.BaseModel `bun:"table:books_renamed"` + var cgot, cwant migrate.Changeset + cgot.Add(got...) + cwant.Add(want...) - ISBN string `bun:"isbn,pk"` - Title string `bun:"title,notnull"` - Pages int `bun:"page_count,notnull,default:0"` - } - return getState(tb, ctx, dialect, - (*Author)(nil), - (*Book)(nil), - ), getState(tb, ctx, dialect, - (*Author)(nil), - (*BookRenamed)(nil), - ) + require.Equal(tb, cwant.String(), cgot.String()) } func getState(tb testing.TB, ctx context.Context, dialect schema.Dialect, models ...interface{}) sqlschema.State { + tb.Helper() + tables := schema.NewTables(dialect) tables.Register(models...) diff --git a/migrate/auto.go b/migrate/auto.go index c137d77e9..c58e243d9 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -167,31 +167,56 @@ func (am *AutoMigrator) Run(ctx context.Context) error { // INTERNAL ------------------------------------------------------------------- -// Operation is an abstraction a level above a MigrationFunc. -// Apart from storing the function to execute the change, -// it knows how to *write* the corresponding code, and what the reverse operation is. -type Operation interface { - Func(sqlschema.Migrator) MigrationFunc - // GetReverse returns an operation that can revert the current one. - GetReverse() Operation -} +type Detector struct{} -type RenameTable struct { - From string - To string -} +func (d *Detector) Diff(got, want sqlschema.State) Changeset { + var changes Changeset -func (rt *RenameTable) Func(m sqlschema.Migrator) MigrationFunc { - return func(ctx context.Context, db *bun.DB) error { - return m.RenameTable(ctx, rt.From, rt.To) + oldModels := newTableSet(got.Tables...) + newModels := newTableSet(want.Tables...) + + addedModels := newModels.Sub(oldModels) + +AddedLoop: + for _, added := range addedModels.Values() { + removedModels := oldModels.Sub(newModels) + for _, removed := range removedModels.Values() { + if d.canRename(added, removed) { + changes.Add(&RenameTable{ + Schema: removed.Schema, + From: removed.Name, + To: added.Name, + }) + + // TODO: check for altered columns. + + // Do not check this model further, we know it was renamed. + oldModels.Remove(removed.Name) + continue AddedLoop + } + } + // If a new table did not appear because of the rename operation, then it must've been created. + changes.Add(&CreateTable{ + Schema: added.Schema, + Name: added.Name, + Model: added.Model, + }) } -} -func (rt *RenameTable) GetReverse() Operation { - return &RenameTable{ - From: rt.To, - To: rt.From, + // Tables that aren't present anymore and weren't renamed were deleted. + for _, t := range oldModels.Sub(newModels).Values() { + changes.Add(&DropTable{ + Schema: t.Schema, + Name: t.Name, + }) } + + return changes +} + +// canRename checks if t1 can be renamed to t2. +func (d Detector) canRename(t1, t2 sqlschema.Table) bool { + return t1.Schema == t2.Schema && sqlschema.EqualSignatures(t1, t2) } // Changeset is a set of changes that alter database state. @@ -201,14 +226,24 @@ type Changeset struct { var _ Operation = (*Changeset)(nil) +func (c Changeset) String() string { + var ops []string + for _, op := range c.operations { + ops = append(ops, op.String()) + } + return strings.Join(ops, "\n") +} + func (c Changeset) Operations() []Operation { return c.operations } -func (c *Changeset) Add(op Operation) { - c.operations = append(c.operations, op) +// Add new operations to the changeset. +func (c *Changeset) Add(op ...Operation) { + c.operations = append(c.operations, op...) } +// Func chains all underlying operations in a single MigrationFunc. func (c *Changeset) Func(m sqlschema.Migrator) MigrationFunc { return func(ctx context.Context, db *bun.DB) error { for _, op := range c.operations { @@ -239,32 +274,118 @@ func (c *Changeset) Down(m sqlschema.Migrator) MigrationFunc { return c.GetReverse().Func(m) } -type Detector struct{} +// Operation is an abstraction a level above a MigrationFunc. +// Apart from storing the function to execute the change, +// it knows how to *write* the corresponding code, and what the reverse operation is. +type Operation interface { + fmt.Stringer -func (d *Detector) Diff(got, want sqlschema.State) Changeset { - var changes Changeset + Func(sqlschema.Migrator) MigrationFunc + // GetReverse returns an operation that can revert the current one. + GetReverse() Operation +} - // Detect renamed models - oldModels := newTableSet(got.Tables...) - newModels := newTableSet(want.Tables...) +// noop is a migration that doesn't change the schema. +type noop struct{} - addedModels := newModels.Sub(oldModels) - for _, added := range addedModels.Values() { - removedModels := oldModels.Sub(newModels) - for _, removed := range removedModels.Values() { - if !sqlschema.EqualSignatures(added, removed) { - continue - } - changes.Add(&RenameTable{ - From: removed.Name, - To: added.Name, - }) - } +var _ Operation = (*noop)(nil) + +func (*noop) String() string { return "noop" } +func (*noop) Func(m sqlschema.Migrator) MigrationFunc { + return func(ctx context.Context, db *bun.DB) error { return nil } +} +func (*noop) GetReverse() Operation { return &noop{} } + +type RenameTable struct { + Schema string + From string + To string +} + +var _ Operation = (*RenameTable)(nil) + +func (op RenameTable) String() string { + return fmt.Sprintf( + "Rename table %q.%q to %q.%q", + op.Schema, trimSchema(op.From), op.Schema, trimSchema(op.To), + ) +} + +func (op *RenameTable) Func(m sqlschema.Migrator) MigrationFunc { + return func(ctx context.Context, db *bun.DB) error { + return m.RenameTable(ctx, op.From, op.To) } +} - return changes +func (op *RenameTable) GetReverse() Operation { + return &RenameTable{ + From: op.To, + To: op.From, + } +} + +type CreateTable struct { + Schema string + Name string + Model interface{} } +var _ Operation = (*CreateTable)(nil) + +func (op CreateTable) String() string { + return fmt.Sprintf("CreateTable %T", op.Model) +} + +func (op *CreateTable) Func(m sqlschema.Migrator) MigrationFunc { + return func(ctx context.Context, db *bun.DB) error { + return m.CreateTable(ctx, op.Model) + } +} + +func (op *CreateTable) GetReverse() Operation { + return &DropTable{ + Schema: op.Schema, + Name: op.Name, + } +} + +type DropTable struct { + Schema string + Name string +} + +var _ Operation = (*DropTable)(nil) + +func (op DropTable) String() string { + return fmt.Sprintf("DropTable %q.%q", op.Schema, trimSchema(op.Name)) +} + +func (op *DropTable) Func(m sqlschema.Migrator) MigrationFunc { + return func(ctx context.Context, db *bun.DB) error { + return m.DropTable(ctx, op.Schema, op.Name) + } +} + +// GetReverse for a DropTable returns a no-op migration. Logically, CreateTable is the reverse, +// but DropTable does not have the table's definition to create one. +// +// TODO: we can fetch table definitions for deleted tables +// from the database engine and execute them as a raw query. +func (op *DropTable) GetReverse() Operation { + return &noop{} +} + +// trimSchema drops schema name from the table name. +// This is a workaroud until schema.Table.Schema is fully integrated with other bun packages. +func trimSchema(name string) string { + if strings.Contains(name, ".") { + return strings.Split(name, ".")[1] + } + return name +} + +// sqlschema utils ------------------------------------------------------------ + // tableSet stores unique table definitions. type tableSet struct { underlying map[string]sqlschema.Table diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index 7974b0c25..2f44f93c5 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -32,6 +32,8 @@ func NewInspector(db *bun.DB, excludeTables ...string) (Inspector, error) { }, nil } +// SchemaInspector creates the current project state from the passed bun.Models. +// Do not recycle SchemaInspector for different sets of models, as older models will not be de-registerred before the next run. type SchemaInspector struct { tables *schema.Tables } @@ -44,8 +46,6 @@ func NewSchemaInspector(tables *schema.Tables) *SchemaInspector { } } -// Inspect creates the current project state from the passed bun.Models. -// Do not recycle SchemaInspector for different sets of models, as older models will not be de-registerred before the next run. func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) { var state State for _, t := range si.tables.All() { @@ -64,6 +64,7 @@ func (si *SchemaInspector) Inspect(ctx context.Context) (State, error) { state.Tables = append(state.Tables, Table{ Schema: t.Schema, Name: t.Name, + Model: t.ZeroIface, Columns: columns, }) } diff --git a/migrate/sqlschema/migrator.go b/migrate/sqlschema/migrator.go index 037c90e23..41b481f77 100644 --- a/migrate/sqlschema/migrator.go +++ b/migrate/sqlschema/migrator.go @@ -15,8 +15,11 @@ type MigratorDialect interface { type Migrator interface { RenameTable(ctx context.Context, oldName, newName string) error + CreateTable(ctx context.Context, model interface{}) error + DropTable(ctx context.Context, schema, table string) error } +// Migrator is a dialect-agnostic wrapper for sqlschema.Dialect type migrator struct { Migrator } @@ -30,3 +33,28 @@ func NewMigrator(db *bun.DB) (Migrator, error) { Migrator: md.Migrator(db), }, nil } + +// BaseMigrator can be embeded by dialect's Migrator implementations to re-use some of the existing bun queries. +type BaseMigrator struct { + db *bun.DB +} + +func NewBaseMigrator(db *bun.DB) *BaseMigrator { + return &BaseMigrator{db: db} +} + +func (m *BaseMigrator) CreateTable(ctx context.Context, model interface{}) error { + _, err := m.db.NewCreateTable().Model(model).Exec(ctx) + if err != nil { + return err + } + return nil +} + +func (m *BaseMigrator) DropTable(ctx context.Context, schema, name string) error { + _, err := m.db.NewDropTable().TableExpr("?.?", bun.Ident(schema), bun.Ident(name)).Exec(ctx) + if err != nil { + return err + } + return nil +} diff --git a/migrate/sqlschema/state.go b/migrate/sqlschema/state.go index 8b89368a4..8f7e96b0d 100644 --- a/migrate/sqlschema/state.go +++ b/migrate/sqlschema/state.go @@ -7,6 +7,7 @@ type State struct { type Table struct { Schema string Name string + Model interface{} Columns map[string]Column } diff --git a/query_table_drop.go b/query_table_drop.go index e4447a8d2..a92014515 100644 --- a/query_table_drop.go +++ b/query_table_drop.go @@ -151,3 +151,12 @@ func (q *DropTableQuery) afterDropTableHook(ctx context.Context) error { } return nil } + +func (q *DropTableQuery) String() string { + buf, err := q.AppendQuery(q.db.Formatter(), nil) + if err != nil { + panic(err) + } + + return string(buf) +} diff --git a/schema/table.go b/schema/table.go index 7be2cef95..2878c21cf 100644 --- a/schema/table.go +++ b/schema/table.go @@ -83,6 +83,7 @@ func newTable(dialect Dialect, typ reflect.Type) *Table { t.setName(tableName) t.Alias = t.ModelName t.SQLAlias = t.quoteIdent(t.ModelName) + t.Schema = t.dialect.DefaultSchema() hooks := []struct { typ reflect.Type @@ -281,7 +282,7 @@ func (t *Table) processBaseModelField(f reflect.StructField) { } if s, ok := tag.Option("table"); ok { - schema, _ := t.schemaFromTagName(tag.Name) + schema, _ := t.schemaFromTagName(s) t.Schema = schema t.setName(s) }