From 18e5c4a9b0a58aae3d1a29cbc003862f5b8845d3 Mon Sep 17 00:00:00 2001 From: Jeff Ortel Date: Mon, 19 Aug 2024 15:57:54 -0700 Subject: [PATCH] Fix primary keys reused. Signed-off-by: Jeff Ortel --- api/migrationwave.go | 1 + cmd/main.go | 5 + database/pk.go | 160 +++++++++++++++++++++++++++++ database/pkg.go | 11 +- migration/v14/model/core.go | 7 ++ migration/v14/model/pkg.go | 1 + model/pkg.go | 3 + test/api/migrationwave/api_test.go | 6 ++ test/api/review/api_test.go | 2 +- test/api/ticket/api_test.go | 10 +- 10 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 database/pk.go diff --git a/api/migrationwave.go b/api/migrationwave.go index b92b23d6b..df75eb7d0 100644 --- a/api/migrationwave.go +++ b/api/migrationwave.go @@ -118,6 +118,7 @@ func (h MigrationWaveHandler) Create(ctx *gin.Context) { _ = ctx.Error(err) return } + r.With(m) h.Respond(ctx, http.StatusCreated, r) diff --git a/cmd/main.go b/cmd/main.go index 8be4667bb..91a6adcb3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -17,6 +17,7 @@ import ( crd "github.com/konveyor/tackle2-hub/k8s/api" "github.com/konveyor/tackle2-hub/metrics" "github.com/konveyor/tackle2-hub/migration" + "github.com/konveyor/tackle2-hub/model" "github.com/konveyor/tackle2-hub/reaper" "github.com/konveyor/tackle2-hub/seed" "github.com/konveyor/tackle2-hub/settings" @@ -53,6 +54,10 @@ func Setup() (db *gorm.DB, err error) { if err != nil { return } + err = database.PK.Load(db, model.ALL) + if err != nil { + return + } return } diff --git a/database/pk.go b/database/pk.go new file mode 100644 index 000000000..d9f19d49a --- /dev/null +++ b/database/pk.go @@ -0,0 +1,160 @@ +package database + +import ( + "errors" + "reflect" + "strings" + "sync" + + "github.com/konveyor/tackle2-hub/model" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" +) + +// PK singleton pk sequence. +var PK PkSequence + +// PkSequence provides a primary key sequence. +type PkSequence struct { + mutex sync.Mutex +} + +// Load highest key for all models. +func (r *PkSequence) Load(db *gorm.DB, models []any) (err error) { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, m := range models { + mt := reflect.TypeOf(m) + if mt.Kind() == reflect.Ptr { + mt = mt.Elem() + } + kind := strings.ToUpper(mt.Name()) + db = r.session(db) + q := db.Table(kind) + q = q.Select("MAX(ID) id") + cursor, err := q.Rows() + if err != nil || !cursor.Next() { + // not a table with id. + // discarded. + continue + } + id := int64(0) + err = cursor.Scan(&id) + _ = cursor.Close() + if err != nil { + r.add(db, kind, uint(0)) + } else { + r.add(db, kind, uint(id)) + } + } + return +} + +// Next returns the next primary key. +func (r *PkSequence) Next(db *gorm.DB) (id uint) { + r.mutex.Lock() + defer r.mutex.Unlock() + kind := strings.ToUpper(db.Statement.Table) + m := &model.PK{} + db = r.session(db) + err := db.First(m, "Kind", kind).Error + if err != nil { + return + } + m.LastID++ + id = m.LastID + err = db.Save(m).Error + if err != nil { + panic(err) + } + return +} + +// session returns a new DB with a new session. +func (r *PkSequence) session(in *gorm.DB) (out *gorm.DB) { + out = &gorm.DB{ + Config: in.Config, + } + out.Config.Logger.LogMode(logger.Warn) + out.Statement = &gorm.Statement{ + DB: out, + ConnPool: in.Statement.ConnPool, + Context: in.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + } + return +} + +// add the last (higher) id for the kind. +func (r *PkSequence) add(db *gorm.DB, kind string, id uint) { + m := &model.PK{Kind: kind} + db = r.session(db) + err := db.First(m).Error + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + panic(err) + } + } + if m.LastID > id { + return + } + m.LastID = id + db = r.session(db) + err = db.Save(m).Error + if err != nil { + panic(err) + } +} + +// assignPk assigns PK as needed. +func assignPk(db *gorm.DB) { + statement := db.Statement + schema := statement.Schema + if schema == nil { + return + } + switch statement.ReflectValue.Kind() { + case reflect.Slice, + reflect.Array: + for i := 0; i < statement.ReflectValue.Len(); i++ { + for _, f := range schema.Fields { + if f.Name != "ID" { + continue + } + _, isZero := f.ValueOf( + statement.Context, + statement.ReflectValue.Index(i)) + if isZero { + id := PK.Next(db) + _ = f.Set( + statement.Context, + statement.ReflectValue.Index(i), + id) + + } + break + } + } + case reflect.Struct: + for _, f := range schema.Fields { + if f.Name != "ID" { + continue + } + _, isZero := f.ValueOf( + statement.Context, + statement.ReflectValue) + if isZero { + id := PK.Next(db) + _ = f.Set( + statement.Context, + statement.ReflectValue, + id) + } + break + } + default: + log.Info("[WARN] assignPk: unknown kind.") + } +} diff --git a/database/pkg.go b/database/pkg.go index b513c1c6c..6e9c5eb20 100644 --- a/database/pkg.go +++ b/database/pkg.go @@ -51,7 +51,16 @@ func Open(enforceFKs bool) (db *gorm.DB, err error) { err = liberr.Wrap(err) return } - err = db.AutoMigrate(model.Setting{}) + err = db.AutoMigrate(model.PK{}, model.Setting{}) + if err != nil { + err = liberr.Wrap(err) + return + } + err = PK.Load(db, []any{model.Setting{}}) + if err != nil { + return + } + err = db.Callback().Create().Before("gorm:before_create").Register("assign-pk", assignPk) if err != nil { err = liberr.Wrap(err) return diff --git a/migration/v14/model/core.go b/migration/v14/model/core.go index a5d6a5ab1..4f384bb6b 100644 --- a/migration/v14/model/core.go +++ b/migration/v14/model/core.go @@ -20,6 +20,13 @@ type Model struct { UpdateUser string } +// PK sequence. +type PK struct { + Kind string `gorm:"<-:create;primaryKey"` + LastID uint +} + +// Setting hub settings. type Setting struct { Model Key string `gorm:"<-:create;uniqueIndex"` diff --git a/migration/v14/model/pkg.go b/migration/v14/model/pkg.go index 6827e3c96..8f612b488 100644 --- a/migration/v14/model/pkg.go +++ b/migration/v14/model/pkg.go @@ -32,6 +32,7 @@ func All() []any { ImportTag{}, JobFunction{}, MigrationWave{}, + PK{}, Proxy{}, Review{}, Setting{}, diff --git a/model/pkg.go b/model/pkg.go index 5cf00cd5a..d6b546fb3 100644 --- a/model/pkg.go +++ b/model/pkg.go @@ -8,6 +8,8 @@ import ( // Field (data) types. type JSON = model.JSON +var ALL = model.All() + // Models type Model = model.Model type Application = model.Application @@ -29,6 +31,7 @@ type ImportSummary = model.ImportSummary type ImportTag = model.ImportTag type JobFunction = model.JobFunction type MigrationWave = model.MigrationWave +type PK = model.PK type Proxy = model.Proxy type Questionnaire = model.Questionnaire type Review = model.Review diff --git a/test/api/migrationwave/api_test.go b/test/api/migrationwave/api_test.go index afa2ac846..7c2622a84 100644 --- a/test/api/migrationwave/api_test.go +++ b/test/api/migrationwave/api_test.go @@ -18,6 +18,7 @@ func TestMigrationWaveCRUD(t *testing.T) { } assert.Must(t, Application.Create(&expectedApp)) createdApps = append(createdApps, expectedApp) + r.Applications[0].ID = expectedApp.ID } createdStakeholders := []api.Stakeholder{} @@ -28,6 +29,7 @@ func TestMigrationWaveCRUD(t *testing.T) { } assert.Must(t, Stakeholder.Create(&expectedStakeholder)) createdStakeholders = append(createdStakeholders, expectedStakeholder) + r.Stakeholders[0].ID = expectedStakeholder.ID } createdStakeholderGroups := []api.StakeholderGroup{} @@ -38,6 +40,7 @@ func TestMigrationWaveCRUD(t *testing.T) { } assert.Must(t, StakeholderGroup.Create(&expectedStakeholderGroup)) createdStakeholderGroups = append(createdStakeholderGroups, expectedStakeholderGroup) + r.StakeholderGroups[0].ID = expectedStakeholderGroup.ID } assert.Must(t, MigrationWave.Create(&r)) @@ -102,6 +105,7 @@ func TestMigrationWaveList(t *testing.T) { } assert.Must(t, Application.Create(&expectedApp)) createdApps = append(createdApps, expectedApp) + r.Applications[0].ID = expectedApp.ID } for _, stakeholder := range r.Stakeholders { @@ -111,6 +115,7 @@ func TestMigrationWaveList(t *testing.T) { } assert.Must(t, Stakeholder.Create(&expectedStakeholder)) createdStakeholders = append(createdStakeholders, expectedStakeholder) + r.Stakeholders[0].ID = expectedStakeholder.ID } for _, stakeholderGroup := range r.StakeholderGroups { @@ -120,6 +125,7 @@ func TestMigrationWaveList(t *testing.T) { } assert.Must(t, StakeholderGroup.Create(&expectedStakeholderGroup)) createdStakeholderGroups = append(createdStakeholderGroups, expectedStakeholderGroup) + r.StakeholderGroups[0].ID = expectedStakeholderGroup.ID } assert.Must(t, MigrationWave.Create(&r)) createdMigrationWaves = append(createdMigrationWaves, r) diff --git a/test/api/review/api_test.go b/test/api/review/api_test.go index 013f84842..d32555859 100644 --- a/test/api/review/api_test.go +++ b/test/api/review/api_test.go @@ -143,7 +143,7 @@ func TestReviewList(t *testing.T) { // Delete related reviews and applications. for _, review := range createdReviews { - assert.Must(t, Application.Delete(review.ID)) + assert.Must(t, Application.Delete(review.Application.ID)) assert.Must(t, Review.Delete(review.ID)) } } diff --git a/test/api/ticket/api_test.go b/test/api/ticket/api_test.go index af6f277c1..71ea3d178 100644 --- a/test/api/ticket/api_test.go +++ b/test/api/ticket/api_test.go @@ -17,6 +17,7 @@ func TestTicketCRUD(t *testing.T) { Name: r.Application.Name, } assert.Must(t, Application.Create(&app)) + r.Application.ID = app.ID createdIdentities := []api.Identity{} createdTrackers := []api.Tracker{} @@ -27,8 +28,11 @@ func TestTicketCRUD(t *testing.T) { Kind: tracker.Kind, } assert.Must(t, Identity.Create(&identity)) + tracker.Identity.ID = identity.ID createdIdentities = append(createdIdentities, identity) assert.Must(t, Tracker.Create(&tracker)) + r.Tracker.ID = tracker.ID + r.Tracker.Name = tracker.Name createdTrackers = append(createdTrackers, tracker) } @@ -72,6 +76,7 @@ func TestTicketList(t *testing.T) { Name: r.Application.Name, } assert.Must(t, Application.Create(&app)) + r.Application.ID = app.ID createdIdentities := []api.Identity{} createdTrackers := []api.Tracker{} @@ -82,8 +87,11 @@ func TestTicketList(t *testing.T) { Kind: tracker.Kind, } assert.Must(t, Identity.Create(&identity)) + tracker.Identity.ID = identity.ID createdIdentities = append(createdIdentities, identity) assert.Must(t, Tracker.Create(&tracker)) + r.Tracker.ID = tracker.ID + r.Tracker.Name = tracker.Name createdTrackers = append(createdTrackers, tracker) } @@ -113,7 +121,7 @@ func TestTicketList(t *testing.T) { // Delete tickets and related resources. for _, ticket := range createdTickets { assert.Must(t, Ticket.Delete(ticket.ID)) - assert.Must(t, Application.Delete(ticket.ID)) + assert.Must(t, Application.Delete(ticket.Application.ID)) } for _, tracker := range createdTrackers { assert.Must(t, Tracker.Delete(tracker.ID))