diff --git a/migrator.go b/migrator.go index 3f80515..ca07209 100644 --- a/migrator.go +++ b/migrator.go @@ -36,6 +36,58 @@ func (m Migrator) GetTables() (tableList []string, err error) { return tableList, m.DB.Raw("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_CATALOG = ?", m.CurrentDatabase()).Scan(&tableList).Error } +func (m Migrator) CreateTable(values ...interface{}) (err error) { + if err = m.Migrator.CreateTable(values...); err != nil { + return + } + for _, value := range m.ReorderModels(values, false) { + if err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + if stmt.Schema == nil { + return + } + for _, fieldName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[fieldName] + if field.Comment == "" { + continue + } + if err = m.setColumnComment(stmt, field, true); err != nil { + return + } + } + return + }); err != nil { + return + } + } + return +} + +func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error { + schemaName := m.getTableSchemaName(stmt.Schema) + // add field comment + if add { + return m.DB.Exec( + "EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", + field.Comment, schemaName, stmt.Table, field.DBName, + ).Error + } + // update field comment + return m.DB.Exec( + "EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", + field.Comment, schemaName, stmt.Table, field.DBName, + ).Error +} + +func (m Migrator) getTableSchemaName(schema *schema.Schema) string { + // return the schema name if it is explicitly provided in the table name + // otherwise return default schema name + schemaName := getTableSchemaName(schema) + if schemaName == "" { + schemaName = m.DefaultSchema() + } + return schemaName +} + func getTableSchemaName(schema *schema.Schema) string { // return the schema name if it is explicitly provided in the table name // otherwise return a sql wildcard -> use any table_schema @@ -141,6 +193,26 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { ).Error } +func (m Migrator) AddColumn(value interface{}, name string) error { + if err := m.Migrator.AddColumn(value, name); err != nil { + return err + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + if field.Comment == "" { + return + } + if err = m.setColumnComment(stmt, field, true); err != nil { + return + } + } + } + return + }) +} + func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -200,6 +272,39 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) { + queryTx := m.DB + if m.DB.DryRun { + queryTx = m.DB.Session(&gorm.Session{}) + queryTx.DryRun = false + } + var comment sql.NullString + queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)", + gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment) + if comment.Valid { + description = comment.String + } + return +} + +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { + return err + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + description := m.GetColumnComment(stmt, field.DBName) + if field.Comment != description { + if description == "" { + err = m.setColumnComment(stmt, field, true) + } else { + err = m.setColumnComment(stmt, field, false) + } + } + return + }) +} + var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$") // ColumnTypes return columnTypes []gorm.ColumnType and execErr error diff --git a/migrator_test.go b/migrator_test.go index fb01944..056a58a 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -188,3 +188,63 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co } return } + +type TestTableFieldComment struct { + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name;comment:姓名"` + Age uint `gorm:"column:age;comment:年龄"` +} + +func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" } + +type TestTableFieldCommentUpdate struct { + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name;comment:姓名"` + Age uint `gorm:"column:age;comment:周岁"` + Birthday *time.Time `gorm:"column:birthday;comment:生日"` +} + +func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" } + +func TestMigrator_MigrateColumnComment(t *testing.T) { + db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) + if err != nil { + t.Error(err) + } + migrator := db.Debug().Migrator() + + tableModel := new(TestTableFieldComment) + defer func() { + if err = migrator.DropTable(tableModel); err != nil { + t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err) + } + }() + + if err = migrator.AutoMigrate(tableModel); err != nil { + t.Fatal(err) + } + tableModelUpdate := new(TestTableFieldCommentUpdate) + if err = migrator.AutoMigrate(tableModelUpdate); err != nil { + t.Error(err) + } + + if m, ok := migrator.(sqlserver.Migrator); ok { + stmt := db.Model(tableModelUpdate).Find(nil).Statement + if stmt == nil || stmt.Schema == nil { + t.Fatal("expected Statement.Schema, got nil") + } + + wantComments := []string{"", "姓名", "周岁", "生日"} + gotComments := make([]string, len(stmt.Schema.DBNames)) + + for i, fieldDBName := range stmt.Schema.DBNames { + comment := m.GetColumnComment(stmt, fieldDBName) + gotComments[i] = comment + } + + if !reflect.DeepEqual(wantComments, gotComments) { + t.Fatalf("expected comments %#v, got %#v", wantComments, gotComments) + } + t.Logf("got comments: %#v", gotComments) + } +}