diff --git a/migrator.go b/migrator.go index bea2422..3f80515 100644 --- a/migrator.go +++ b/migrator.go @@ -256,8 +256,6 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { column.DefaultValueValue.String = matches[1] matches = defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) } - } else { - column.DefaultValueValue.Valid = true } for _, c := range rawColumnTypes { diff --git a/migrator_test.go b/migrator_test.go index bfeffd4..fb01944 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -2,7 +2,9 @@ package sqlserver_test import ( "os" + "reflect" "testing" + "time" "gorm.io/driver/sqlserver" "gorm.io/gorm" @@ -115,3 +117,74 @@ func TestCreateIndex(t *testing.T) { t.Error("couldn't drop table testtable", tx.Error) } } + +type TestTableDefaultValue struct { + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name"` + Age uint `gorm:"column:age"` + Birthday *time.Time `gorm:"column:birthday"` + CompanyID *int `gorm:"column:company_id;default:0"` + ManagerID *uint `gorm:"column:manager_id;default:0"` + Active bool `gorm:"column:active;default:1"` +} + +func (*TestTableDefaultValue) TableName() string { return "test_table_default_value" } + +func TestReMigrateTableFieldsWithoutDefaultValue(t *testing.T) { + db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) + if err != nil { + t.Error(err) + } + + var ( + migrator = db.Migrator() + tableModel = new(TestTableDefaultValue) + fieldsWithDefault = []string{"company_id", "manager_id", "active"} + fieldsWithoutDefault = []string{"id", "name", "age", "birthday"} + + columnsWithDefault []string + columnsWithoutDefault []string + ) + + defer func() { + if err = migrator.DropTable(tableModel); err != nil { + t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err) + } + }() + if !migrator.HasTable(tableModel) { + if err = migrator.AutoMigrate(tableModel); err != nil { + t.Errorf("couldn't auto migrate table %q, got error: %v", tableModel.TableName(), err) + } + } + // If in the `Migrator.ColumnTypes` method `column.DefaultValueValue.Valid = true`, + // re-migrate the table will alter all fields without default value except for the primary key. + if err = db.Debug().Migrator().AutoMigrate(tableModel); err != nil { + t.Errorf("couldn't re-migrate table %q, got error: %v", tableModel.TableName(), err) + } + + columnsWithDefault, columnsWithoutDefault, err = testGetMigrateColumns(db, tableModel) + if !reflect.DeepEqual(columnsWithDefault, fieldsWithDefault) { + // If in the `Migrator.ColumnTypes` method `column.DefaultValueValue.Valid = true`, + // fields with default value will include all fields: `[id name age birthday company_id manager_id active]`. + t.Errorf("expected columns with default value %v, got %v", fieldsWithDefault, columnsWithDefault) + } + if !reflect.DeepEqual(columnsWithoutDefault, fieldsWithoutDefault) { + t.Errorf("expected columns without default value %v, got %v", fieldsWithoutDefault, columnsWithoutDefault) + } +} + +func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, columnsWithoutDefault []string, err error) { + migrator := db.Migrator() + var columnTypes []gorm.ColumnType + if columnTypes, err = migrator.ColumnTypes(dst); err != nil { + return + } + for _, columnType := range columnTypes { + if _, ok := columnType.DefaultValue(); ok { + columnsWithDefault = append(columnsWithDefault, columnType.Name()) + } else { + columnsWithoutDefault = append(columnsWithoutDefault, columnType.Name()) + } + } + return +}