From 384d7f8a65dfae854f5dadbf055dc8976ab208cb Mon Sep 17 00:00:00 2001 From: Jeff Ortel Date: Mon, 9 Sep 2024 11:53:21 -0500 Subject: [PATCH] :sparkles: Generate Primary Keys. (#635) Generate primary keys instead of GORM. This fixes the issue of GORM reusing the highest key after the model with that ID is deleted. When the PK is 0, GORM assigns the next (highest) ID. This approach is to assign the ID ahead of time using a pool managed by tackle. --------- Signed-off-by: Jeff Ortel --- api/migrationwave.go | 1 + cmd/main.go | 5 + database/db_test.go | 53 ++++++++++ database/pk.go | 160 +++++++++++++++++++++++++++++ database/pkg.go | 11 +- migration/v14/model/core.go | 7 ++ migration/v14/model/pkg.go | 1 + model/pkg.go | 3 + test/api/migrationwave/api_test.go | 6 ++ test/api/migrationwave/samples.go | 3 - test/api/review/api_test.go | 2 +- test/api/review/samples.go | 2 - test/api/ticket/api_test.go | 10 +- test/api/ticket/samples.go | 5 +- test/api/tracker/samples.go | 2 - 15 files changed, 257 insertions(+), 14 deletions(-) create mode 100644 database/pk.go diff --git a/api/migrationwave.go b/api/migrationwave.go index b92b23d6b..df75eb7d0 100644 --- a/api/migrationwave.go +++ b/api/migrationwave.go @@ -118,6 +118,7 @@ func (h MigrationWaveHandler) Create(ctx *gin.Context) { _ = ctx.Error(err) return } + r.With(m) h.Respond(ctx, http.StatusCreated, r) diff --git a/cmd/main.go b/cmd/main.go index 8be4667bb..91a6adcb3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -17,6 +17,7 @@ import ( crd "github.com/konveyor/tackle2-hub/k8s/api" "github.com/konveyor/tackle2-hub/metrics" "github.com/konveyor/tackle2-hub/migration" + "github.com/konveyor/tackle2-hub/model" "github.com/konveyor/tackle2-hub/reaper" "github.com/konveyor/tackle2-hub/seed" "github.com/konveyor/tackle2-hub/settings" @@ -53,6 +54,10 @@ func Setup() (db *gorm.DB, err error) { if err != nil { return } + err = database.PK.Load(db, model.ALL) + if err != nil { + return + } return } diff --git a/database/db_test.go b/database/db_test.go index 37102df91..dde8c6fae 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -69,3 +69,56 @@ func TestConcurrent(t *testing.T) { fmt.Printf("Done %d\n", id) } } + +func TestKeyGen(t *testing.T) { + pid := os.Getpid() + Settings.DB.Path = fmt.Sprintf("/tmp/keygen-%d.db", pid) + defer func() { + _ = os.Remove(Settings.DB.Path) + }() + db, err := Open(true) + if err != nil { + panic(err) + } + // ids 1-7 created. + N = 8 + for n := 1; n < N; n++ { + m := &model.Setting{Key: fmt.Sprintf("key-%d", n), Value: n} + err := db.Create(m).Error + if err != nil { + panic(err) + } + fmt.Printf("CREATED: %d/%d\n", m.ID, n) + if uint(n) != m.ID { + t.Errorf("id:%d but expected: %d", m.ID, n) + return + } + } + // delete ids=2,4,7. + err = db.Delete(&model.Setting{}, []uint{2, 4, 7}).Error + if err != nil { + panic(err) + } + + var count int64 + err = db.Model(&model.Setting{}).Where([]uint{2, 4, 7}).Count(&count).Error + if err != nil { + panic(err) + } + if count > 0 { + t.Errorf("DELETED ids: 2,4,7 found.") + return + } + // id=8 (next) created. + next := N + m := &model.Setting{Key: fmt.Sprintf("key-%d", next), Value: next} + err = db.Create(m).Error + if err != nil { + panic(err) + } + fmt.Printf("CREATED: %d/%d (next)\n", m.ID, next) + if uint(N) != m.ID { + t.Errorf("id:%d but expected: %d", m.ID, next) + return + } +} diff --git a/database/pk.go b/database/pk.go new file mode 100644 index 000000000..d9f19d49a --- /dev/null +++ b/database/pk.go @@ -0,0 +1,160 @@ +package database + +import ( + "errors" + "reflect" + "strings" + "sync" + + "github.com/konveyor/tackle2-hub/model" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// PK singleton pk sequence. +var PK PkSequence + +// PkSequence provides a primary key sequence. +type PkSequence struct { + mutex sync.Mutex +} + +// Load highest key for all models. +func (r *PkSequence) Load(db *gorm.DB, models []any) (err error) { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, m := range models { + mt := reflect.TypeOf(m) + if mt.Kind() == reflect.Ptr { + mt = mt.Elem() + } + kind := strings.ToUpper(mt.Name()) + db = r.session(db) + q := db.Table(kind) + q = q.Select("MAX(ID) id") + cursor, err := q.Rows() + if err != nil || !cursor.Next() { + // not a table with id. + // discarded. + continue + } + id := int64(0) + err = cursor.Scan(&id) + _ = cursor.Close() + if err != nil { + r.add(db, kind, uint(0)) + } else { + r.add(db, kind, uint(id)) + } + } + return +} + +// Next returns the next primary key. +func (r *PkSequence) Next(db *gorm.DB) (id uint) { + r.mutex.Lock() + defer r.mutex.Unlock() + kind := strings.ToUpper(db.Statement.Table) + m := &model.PK{} + db = r.session(db) + err := db.First(m, "Kind", kind).Error + if err != nil { + return + } + m.LastID++ + id = m.LastID + err = db.Save(m).Error + if err != nil { + panic(err) + } + return +} + +// session returns a new DB with a new session. +func (r *PkSequence) session(in *gorm.DB) (out *gorm.DB) { + out = &gorm.DB{ + Config: in.Config, + } + out.Config.Logger.LogMode(logger.Warn) + out.Statement = &gorm.Statement{ + DB: out, + ConnPool: in.Statement.ConnPool, + Context: in.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + } + return +} + +// add the last (higher) id for the kind. +func (r *PkSequence) add(db *gorm.DB, kind string, id uint) { + m := &model.PK{Kind: kind} + db = r.session(db) + err := db.First(m).Error + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + panic(err) + } + } + if m.LastID > id { + return + } + m.LastID = id + db = r.session(db) + err = db.Save(m).Error + if err != nil { + panic(err) + } +} + +// assignPk assigns PK as needed. +func assignPk(db *gorm.DB) { + statement := db.Statement + schema := statement.Schema + if schema == nil { + return + } + switch statement.ReflectValue.Kind() { + case reflect.Slice, + reflect.Array: + for i := 0; i < statement.ReflectValue.Len(); i++ { + for _, f := range schema.Fields { + if f.Name != "ID" { + continue + } + _, isZero := f.ValueOf( + statement.Context, + statement.ReflectValue.Index(i)) + if isZero { + id := PK.Next(db) + _ = f.Set( + statement.Context, + statement.ReflectValue.Index(i), + id) + + } + break + } + } + case reflect.Struct: + for _, f := range schema.Fields { + if f.Name != "ID" { + continue + } + _, isZero := f.ValueOf( + statement.Context, + statement.ReflectValue) + if isZero { + id := PK.Next(db) + _ = f.Set( + statement.Context, + statement.ReflectValue, + id) + } + break + } + default: + log.Info("[WARN] assignPk: unknown kind.") + } +} diff --git a/database/pkg.go b/database/pkg.go index b513c1c6c..6e9c5eb20 100644 --- a/database/pkg.go +++ b/database/pkg.go @@ -51,7 +51,16 @@ func Open(enforceFKs bool) (db *gorm.DB, err error) { err = liberr.Wrap(err) return } - err = db.AutoMigrate(model.Setting{}) + err = db.AutoMigrate(model.PK{}, model.Setting{}) + if err != nil { + err = liberr.Wrap(err) + return + } + err = PK.Load(db, []any{model.Setting{}}) + if err != nil { + return + } + err = db.Callback().Create().Before("gorm:before_create").Register("assign-pk", assignPk) if err != nil { err = liberr.Wrap(err) return diff --git a/migration/v14/model/core.go b/migration/v14/model/core.go index a5d6a5ab1..4f384bb6b 100644 --- a/migration/v14/model/core.go +++ b/migration/v14/model/core.go @@ -20,6 +20,13 @@ type Model struct { UpdateUser string } +// PK sequence. +type PK struct { + Kind string `gorm:"<-:create;primaryKey"` + LastID uint +} + +// Setting hub settings. type Setting struct { Model Key string `gorm:"<-:create;uniqueIndex"` diff --git a/migration/v14/model/pkg.go b/migration/v14/model/pkg.go index 6827e3c96..8f612b488 100644 --- a/migration/v14/model/pkg.go +++ b/migration/v14/model/pkg.go @@ -32,6 +32,7 @@ func All() []any { ImportTag{}, JobFunction{}, MigrationWave{}, + PK{}, Proxy{}, Review{}, Setting{}, diff --git a/model/pkg.go b/model/pkg.go index 5cf00cd5a..d6b546fb3 100644 --- a/model/pkg.go +++ b/model/pkg.go @@ -8,6 +8,8 @@ import ( // Field (data) types. type JSON = model.JSON +var ALL = model.All() + // Models type Model = model.Model type Application = model.Application @@ -29,6 +31,7 @@ type ImportSummary = model.ImportSummary type ImportTag = model.ImportTag type JobFunction = model.JobFunction type MigrationWave = model.MigrationWave +type PK = model.PK type Proxy = model.Proxy type Questionnaire = model.Questionnaire type Review = model.Review diff --git a/test/api/migrationwave/api_test.go b/test/api/migrationwave/api_test.go index afa2ac846..7c2622a84 100644 --- a/test/api/migrationwave/api_test.go +++ b/test/api/migrationwave/api_test.go @@ -18,6 +18,7 @@ func TestMigrationWaveCRUD(t *testing.T) { } assert.Must(t, Application.Create(&expectedApp)) createdApps = append(createdApps, expectedApp) + r.Applications[0].ID = expectedApp.ID } createdStakeholders := []api.Stakeholder{} @@ -28,6 +29,7 @@ func TestMigrationWaveCRUD(t *testing.T) { } assert.Must(t, Stakeholder.Create(&expectedStakeholder)) createdStakeholders = append(createdStakeholders, expectedStakeholder) + r.Stakeholders[0].ID = expectedStakeholder.ID } createdStakeholderGroups := []api.StakeholderGroup{} @@ -38,6 +40,7 @@ func TestMigrationWaveCRUD(t *testing.T) { } assert.Must(t, StakeholderGroup.Create(&expectedStakeholderGroup)) createdStakeholderGroups = append(createdStakeholderGroups, expectedStakeholderGroup) + r.StakeholderGroups[0].ID = expectedStakeholderGroup.ID } assert.Must(t, MigrationWave.Create(&r)) @@ -102,6 +105,7 @@ func TestMigrationWaveList(t *testing.T) { } assert.Must(t, Application.Create(&expectedApp)) createdApps = append(createdApps, expectedApp) + r.Applications[0].ID = expectedApp.ID } for _, stakeholder := range r.Stakeholders { @@ -111,6 +115,7 @@ func TestMigrationWaveList(t *testing.T) { } assert.Must(t, Stakeholder.Create(&expectedStakeholder)) createdStakeholders = append(createdStakeholders, expectedStakeholder) + r.Stakeholders[0].ID = expectedStakeholder.ID } for _, stakeholderGroup := range r.StakeholderGroups { @@ -120,6 +125,7 @@ func TestMigrationWaveList(t *testing.T) { } assert.Must(t, StakeholderGroup.Create(&expectedStakeholderGroup)) createdStakeholderGroups = append(createdStakeholderGroups, expectedStakeholderGroup) + r.StakeholderGroups[0].ID = expectedStakeholderGroup.ID } assert.Must(t, MigrationWave.Create(&r)) createdMigrationWaves = append(createdMigrationWaves, r) diff --git a/test/api/migrationwave/samples.go b/test/api/migrationwave/samples.go index 4c0ef9fa1..535d5dbae 100644 --- a/test/api/migrationwave/samples.go +++ b/test/api/migrationwave/samples.go @@ -13,19 +13,16 @@ var Samples = []api.MigrationWave{ EndDate: time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Local).Add(30 * time.Minute), Applications: []api.Ref{ { - ID: 1, Name: "Sample Application", }, }, Stakeholders: []api.Ref{ { - ID: 1, Name: "Sample Stakeholders", }, }, StakeholderGroups: []api.Ref{ { - ID: 1, Name: "Sample Stakeholders Groups", }, }, diff --git a/test/api/review/api_test.go b/test/api/review/api_test.go index 013f84842..d32555859 100644 --- a/test/api/review/api_test.go +++ b/test/api/review/api_test.go @@ -143,7 +143,7 @@ func TestReviewList(t *testing.T) { // Delete related reviews and applications. for _, review := range createdReviews { - assert.Must(t, Application.Delete(review.ID)) + assert.Must(t, Application.Delete(review.Application.ID)) assert.Must(t, Review.Delete(review.ID)) } } diff --git a/test/api/review/samples.go b/test/api/review/samples.go index 4fb33b981..64df44e99 100644 --- a/test/api/review/samples.go +++ b/test/api/review/samples.go @@ -12,7 +12,6 @@ var Samples = []api.Review{ WorkPriority: 1, Comments: "nil", Application: &api.Ref{ - ID: 1, Name: "Sample Review 1", }, }, @@ -23,7 +22,6 @@ var Samples = []api.Review{ WorkPriority: 2, Comments: "nil", Application: &api.Ref{ - ID: 2, Name: "Sample Review 2", }, }, diff --git a/test/api/ticket/api_test.go b/test/api/ticket/api_test.go index af6f277c1..71ea3d178 100644 --- a/test/api/ticket/api_test.go +++ b/test/api/ticket/api_test.go @@ -17,6 +17,7 @@ func TestTicketCRUD(t *testing.T) { Name: r.Application.Name, } assert.Must(t, Application.Create(&app)) + r.Application.ID = app.ID createdIdentities := []api.Identity{} createdTrackers := []api.Tracker{} @@ -27,8 +28,11 @@ func TestTicketCRUD(t *testing.T) { Kind: tracker.Kind, } assert.Must(t, Identity.Create(&identity)) + tracker.Identity.ID = identity.ID createdIdentities = append(createdIdentities, identity) assert.Must(t, Tracker.Create(&tracker)) + r.Tracker.ID = tracker.ID + r.Tracker.Name = tracker.Name createdTrackers = append(createdTrackers, tracker) } @@ -72,6 +76,7 @@ func TestTicketList(t *testing.T) { Name: r.Application.Name, } assert.Must(t, Application.Create(&app)) + r.Application.ID = app.ID createdIdentities := []api.Identity{} createdTrackers := []api.Tracker{} @@ -82,8 +87,11 @@ func TestTicketList(t *testing.T) { Kind: tracker.Kind, } assert.Must(t, Identity.Create(&identity)) + tracker.Identity.ID = identity.ID createdIdentities = append(createdIdentities, identity) assert.Must(t, Tracker.Create(&tracker)) + r.Tracker.ID = tracker.ID + r.Tracker.Name = tracker.Name createdTrackers = append(createdTrackers, tracker) } @@ -113,7 +121,7 @@ func TestTicketList(t *testing.T) { // Delete tickets and related resources. for _, ticket := range createdTickets { assert.Must(t, Ticket.Delete(ticket.ID)) - assert.Must(t, Application.Delete(ticket.ID)) + assert.Must(t, Application.Delete(ticket.Application.ID)) } for _, tracker := range createdTrackers { assert.Must(t, Tracker.Delete(tracker.ID)) diff --git a/test/api/ticket/samples.go b/test/api/ticket/samples.go index 5a74246ad..a4fe35eb4 100644 --- a/test/api/ticket/samples.go +++ b/test/api/ticket/samples.go @@ -2,7 +2,6 @@ package ticket import ( "github.com/konveyor/tackle2-hub/api" - TrackerSamples "github.com/konveyor/tackle2-hub/test/api/tracker" ) var Samples = []api.Ticket{ @@ -10,12 +9,10 @@ var Samples = []api.Ticket{ Kind: "10001", Parent: "10000", Application: api.Ref{ - ID: 1, Name: "Sample Application1", }, Tracker: api.Ref{ - ID: 1, - Name: TrackerSamples.Samples[0].Name, + Name: "Sample Ticket-Tracker", }, }, } diff --git a/test/api/tracker/samples.go b/test/api/tracker/samples.go index 1face0d01..e9701e729 100644 --- a/test/api/tracker/samples.go +++ b/test/api/tracker/samples.go @@ -14,7 +14,6 @@ var Samples = []api.Tracker{ Message: "Description of tracker", LastUpdated: time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Local), Identity: api.Ref{ - ID: 1, Name: "Sample Tracker Identity", }, Insecure: false, @@ -26,7 +25,6 @@ var Samples = []api.Tracker{ Message: "Description of tracker1", LastUpdated: time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Local), Identity: api.Ref{ - ID: 2, Name: "Sample Tracker Identity1", }, Insecure: false,