From 5e43d200c6780343db6a60d1f991433c084d72b0 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 26 Mar 2024 17:40:46 -0700 Subject: [PATCH 1/9] Add tests for DB data converters --- internal/db/convert_test.go | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 internal/db/convert_test.go diff --git a/internal/db/convert_test.go b/internal/db/convert_test.go new file mode 100644 index 000000000..00b4a267c --- /dev/null +++ b/internal/db/convert_test.go @@ -0,0 +1,41 @@ +package db_test + +import ( + "database/sql" + "testing" + "time" + + "github.com/artefactual-sdps/enduro/internal/db" + "gotest.tools/v3/assert" +) + +func TestFormatOptionalString(t *testing.T) { + t.Run("Returns nil pointer for an empty string", func(t *testing.T) { + t.Parallel() + got := db.FormatOptionalString("") + assert.Assert(t, got == nil) + }) + + t.Run("Returns a pointer to a string", func(t *testing.T) { + t.Parallel() + got := db.FormatOptionalString("foo") + assert.Equal(t, *got, "foo") + }) +} + +func TestFormatOptionalTime(t *testing.T) { + t.Run("Returns nil pointer for null time", func(t *testing.T) { + t.Parallel() + got := db.FormatOptionalTime(sql.NullTime{}) + assert.Assert(t, got == nil) + }) + + t.Run("Returns an RFC3339 time string", func(t *testing.T) { + t.Parallel() + got := db.FormatOptionalTime(sql.NullTime{ + Time: time.Date(2024, 3, 6, 11, 57, 17, 115, time.UTC), + Valid: true, + }) + assert.Equal(t, *got, "2024-03-06T11:57:17Z") + }) +} From 07557576323fc40454ef643f63f7ada1006dffb7 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 26 Mar 2024 16:57:42 -0700 Subject: [PATCH 2/9] Make ent timestamps nullable - Make the ent preservation_action `StartedAt` and `CompletedAt` columns nullable - Make the ent preservation_task `StartedAt` and `CompletedAt` columns nullable --- internal/persistence/ent/db/migrate/schema.go | 8 +- internal/persistence/ent/db/mutation.go | 86 ++++++++++++++++++- .../ent/db/preservationaction/where.go | 20 +++++ .../ent/db/preservationaction_create.go | 22 +++-- .../ent/db/preservationaction_update.go | 36 ++++++++ .../ent/db/preservationtask/where.go | 20 +++++ .../ent/db/preservationtask_create.go | 22 +++-- .../ent/db/preservationtask_update.go | 36 ++++++++ .../ent/schema/preservation_action.go | 6 +- .../ent/schema/preservation_task.go | 6 +- 10 files changed, 240 insertions(+), 22 deletions(-) diff --git a/internal/persistence/ent/db/migrate/schema.go b/internal/persistence/ent/db/migrate/schema.go index 34daf46d5..77c1c081e 100644 --- a/internal/persistence/ent/db/migrate/schema.go +++ b/internal/persistence/ent/db/migrate/schema.go @@ -69,8 +69,8 @@ var ( {Name: "workflow_id", Type: field.TypeString, Size: 255}, {Name: "type", Type: field.TypeInt8}, {Name: "status", Type: field.TypeInt8}, - {Name: "started_at", Type: field.TypeTime}, - {Name: "completed_at", Type: field.TypeTime}, + {Name: "started_at", Type: field.TypeTime, Nullable: true}, + {Name: "completed_at", Type: field.TypeTime, Nullable: true}, {Name: "package_id", Type: field.TypeInt}, } // PreservationActionTable holds the schema information for the "preservation_action" table. @@ -93,8 +93,8 @@ var ( {Name: "task_id", Type: field.TypeUUID}, {Name: "name", Type: field.TypeString, Size: 2048}, {Name: "status", Type: field.TypeInt8}, - {Name: "started_at", Type: field.TypeTime}, - {Name: "completed_at", Type: field.TypeTime}, + {Name: "started_at", Type: field.TypeTime, Nullable: true}, + {Name: "completed_at", Type: field.TypeTime, Nullable: true}, {Name: "note", Type: field.TypeString, Size: 2147483647}, {Name: "preservation_action_id", Type: field.TypeInt}, } diff --git a/internal/persistence/ent/db/mutation.go b/internal/persistence/ent/db/mutation.go index b1397b059..13c7cd1ec 100644 --- a/internal/persistence/ent/db/mutation.go +++ b/internal/persistence/ent/db/mutation.go @@ -1299,9 +1299,22 @@ func (m *PreservationActionMutation) OldStartedAt(ctx context.Context) (v time.T return oldValue.StartedAt, nil } +// ClearStartedAt clears the value of the "started_at" field. +func (m *PreservationActionMutation) ClearStartedAt() { + m.started_at = nil + m.clearedFields[preservationaction.FieldStartedAt] = struct{}{} +} + +// StartedAtCleared returns if the "started_at" field was cleared in this mutation. +func (m *PreservationActionMutation) StartedAtCleared() bool { + _, ok := m.clearedFields[preservationaction.FieldStartedAt] + return ok +} + // ResetStartedAt resets all changes to the "started_at" field. func (m *PreservationActionMutation) ResetStartedAt() { m.started_at = nil + delete(m.clearedFields, preservationaction.FieldStartedAt) } // SetCompletedAt sets the "completed_at" field. @@ -1335,9 +1348,22 @@ func (m *PreservationActionMutation) OldCompletedAt(ctx context.Context) (v time return oldValue.CompletedAt, nil } +// ClearCompletedAt clears the value of the "completed_at" field. +func (m *PreservationActionMutation) ClearCompletedAt() { + m.completed_at = nil + m.clearedFields[preservationaction.FieldCompletedAt] = struct{}{} +} + +// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. +func (m *PreservationActionMutation) CompletedAtCleared() bool { + _, ok := m.clearedFields[preservationaction.FieldCompletedAt] + return ok +} + // ResetCompletedAt resets all changes to the "completed_at" field. func (m *PreservationActionMutation) ResetCompletedAt() { m.completed_at = nil + delete(m.clearedFields, preservationaction.FieldCompletedAt) } // SetPackageID sets the "package_id" field. @@ -1658,7 +1684,14 @@ func (m *PreservationActionMutation) AddField(name string, value ent.Value) erro // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *PreservationActionMutation) ClearedFields() []string { - return nil + var fields []string + if m.FieldCleared(preservationaction.FieldStartedAt) { + fields = append(fields, preservationaction.FieldStartedAt) + } + if m.FieldCleared(preservationaction.FieldCompletedAt) { + fields = append(fields, preservationaction.FieldCompletedAt) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was @@ -1671,6 +1704,14 @@ func (m *PreservationActionMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *PreservationActionMutation) ClearField(name string) error { + switch name { + case preservationaction.FieldStartedAt: + m.ClearStartedAt() + return nil + case preservationaction.FieldCompletedAt: + m.ClearCompletedAt() + return nil + } return fmt.Errorf("unknown PreservationAction nullable field %s", name) } @@ -2080,9 +2121,22 @@ func (m *PreservationTaskMutation) OldStartedAt(ctx context.Context) (v time.Tim return oldValue.StartedAt, nil } +// ClearStartedAt clears the value of the "started_at" field. +func (m *PreservationTaskMutation) ClearStartedAt() { + m.started_at = nil + m.clearedFields[preservationtask.FieldStartedAt] = struct{}{} +} + +// StartedAtCleared returns if the "started_at" field was cleared in this mutation. +func (m *PreservationTaskMutation) StartedAtCleared() bool { + _, ok := m.clearedFields[preservationtask.FieldStartedAt] + return ok +} + // ResetStartedAt resets all changes to the "started_at" field. func (m *PreservationTaskMutation) ResetStartedAt() { m.started_at = nil + delete(m.clearedFields, preservationtask.FieldStartedAt) } // SetCompletedAt sets the "completed_at" field. @@ -2116,9 +2170,22 @@ func (m *PreservationTaskMutation) OldCompletedAt(ctx context.Context) (v time.T return oldValue.CompletedAt, nil } +// ClearCompletedAt clears the value of the "completed_at" field. +func (m *PreservationTaskMutation) ClearCompletedAt() { + m.completed_at = nil + m.clearedFields[preservationtask.FieldCompletedAt] = struct{}{} +} + +// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. +func (m *PreservationTaskMutation) CompletedAtCleared() bool { + _, ok := m.clearedFields[preservationtask.FieldCompletedAt] + return ok +} + // ResetCompletedAt resets all changes to the "completed_at" field. func (m *PreservationTaskMutation) ResetCompletedAt() { m.completed_at = nil + delete(m.clearedFields, preservationtask.FieldCompletedAt) } // SetNote sets the "note" field. @@ -2436,7 +2503,14 @@ func (m *PreservationTaskMutation) AddField(name string, value ent.Value) error // ClearedFields returns all nullable fields that were cleared during this // mutation. func (m *PreservationTaskMutation) ClearedFields() []string { - return nil + var fields []string + if m.FieldCleared(preservationtask.FieldStartedAt) { + fields = append(fields, preservationtask.FieldStartedAt) + } + if m.FieldCleared(preservationtask.FieldCompletedAt) { + fields = append(fields, preservationtask.FieldCompletedAt) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was @@ -2449,6 +2523,14 @@ func (m *PreservationTaskMutation) FieldCleared(name string) bool { // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. func (m *PreservationTaskMutation) ClearField(name string) error { + switch name { + case preservationtask.FieldStartedAt: + m.ClearStartedAt() + return nil + case preservationtask.FieldCompletedAt: + m.ClearCompletedAt() + return nil + } return fmt.Errorf("unknown PreservationTask nullable field %s", name) } diff --git a/internal/persistence/ent/db/preservationaction/where.go b/internal/persistence/ent/db/preservationaction/where.go index 91b1ddbbb..f9cff848b 100644 --- a/internal/persistence/ent/db/preservationaction/where.go +++ b/internal/persistence/ent/db/preservationaction/where.go @@ -270,6 +270,16 @@ func StartedAtLTE(v time.Time) predicate.PreservationAction { return predicate.PreservationAction(sql.FieldLTE(FieldStartedAt, v)) } +// StartedAtIsNil applies the IsNil predicate on the "started_at" field. +func StartedAtIsNil() predicate.PreservationAction { + return predicate.PreservationAction(sql.FieldIsNull(FieldStartedAt)) +} + +// StartedAtNotNil applies the NotNil predicate on the "started_at" field. +func StartedAtNotNil() predicate.PreservationAction { + return predicate.PreservationAction(sql.FieldNotNull(FieldStartedAt)) +} + // CompletedAtEQ applies the EQ predicate on the "completed_at" field. func CompletedAtEQ(v time.Time) predicate.PreservationAction { return predicate.PreservationAction(sql.FieldEQ(FieldCompletedAt, v)) @@ -310,6 +320,16 @@ func CompletedAtLTE(v time.Time) predicate.PreservationAction { return predicate.PreservationAction(sql.FieldLTE(FieldCompletedAt, v)) } +// CompletedAtIsNil applies the IsNil predicate on the "completed_at" field. +func CompletedAtIsNil() predicate.PreservationAction { + return predicate.PreservationAction(sql.FieldIsNull(FieldCompletedAt)) +} + +// CompletedAtNotNil applies the NotNil predicate on the "completed_at" field. +func CompletedAtNotNil() predicate.PreservationAction { + return predicate.PreservationAction(sql.FieldNotNull(FieldCompletedAt)) +} + // PackageIDEQ applies the EQ predicate on the "package_id" field. func PackageIDEQ(v int) predicate.PreservationAction { return predicate.PreservationAction(sql.FieldEQ(FieldPackageID, v)) diff --git a/internal/persistence/ent/db/preservationaction_create.go b/internal/persistence/ent/db/preservationaction_create.go index c69aa8988..08b997591 100644 --- a/internal/persistence/ent/db/preservationaction_create.go +++ b/internal/persistence/ent/db/preservationaction_create.go @@ -46,12 +46,28 @@ func (pac *PreservationActionCreate) SetStartedAt(t time.Time) *PreservationActi return pac } +// SetNillableStartedAt sets the "started_at" field if the given value is not nil. +func (pac *PreservationActionCreate) SetNillableStartedAt(t *time.Time) *PreservationActionCreate { + if t != nil { + pac.SetStartedAt(*t) + } + return pac +} + // SetCompletedAt sets the "completed_at" field. func (pac *PreservationActionCreate) SetCompletedAt(t time.Time) *PreservationActionCreate { pac.mutation.SetCompletedAt(t) return pac } +// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil. +func (pac *PreservationActionCreate) SetNillableCompletedAt(t *time.Time) *PreservationActionCreate { + if t != nil { + pac.SetCompletedAt(*t) + } + return pac +} + // SetPackageID sets the "package_id" field. func (pac *PreservationActionCreate) SetPackageID(i int) *PreservationActionCreate { pac.mutation.SetPackageID(i) @@ -121,12 +137,6 @@ func (pac *PreservationActionCreate) check() error { if _, ok := pac.mutation.Status(); !ok { return &ValidationError{Name: "status", err: errors.New(`db: missing required field "PreservationAction.status"`)} } - if _, ok := pac.mutation.StartedAt(); !ok { - return &ValidationError{Name: "started_at", err: errors.New(`db: missing required field "PreservationAction.started_at"`)} - } - if _, ok := pac.mutation.CompletedAt(); !ok { - return &ValidationError{Name: "completed_at", err: errors.New(`db: missing required field "PreservationAction.completed_at"`)} - } if _, ok := pac.mutation.PackageID(); !ok { return &ValidationError{Name: "package_id", err: errors.New(`db: missing required field "PreservationAction.package_id"`)} } diff --git a/internal/persistence/ent/db/preservationaction_update.go b/internal/persistence/ent/db/preservationaction_update.go index cc45d8366..c0360845a 100644 --- a/internal/persistence/ent/db/preservationaction_update.go +++ b/internal/persistence/ent/db/preservationaction_update.go @@ -100,6 +100,12 @@ func (pau *PreservationActionUpdate) SetNillableStartedAt(t *time.Time) *Preserv return pau } +// ClearStartedAt clears the value of the "started_at" field. +func (pau *PreservationActionUpdate) ClearStartedAt() *PreservationActionUpdate { + pau.mutation.ClearStartedAt() + return pau +} + // SetCompletedAt sets the "completed_at" field. func (pau *PreservationActionUpdate) SetCompletedAt(t time.Time) *PreservationActionUpdate { pau.mutation.SetCompletedAt(t) @@ -114,6 +120,12 @@ func (pau *PreservationActionUpdate) SetNillableCompletedAt(t *time.Time) *Prese return pau } +// ClearCompletedAt clears the value of the "completed_at" field. +func (pau *PreservationActionUpdate) ClearCompletedAt() *PreservationActionUpdate { + pau.mutation.ClearCompletedAt() + return pau +} + // SetPackageID sets the "package_id" field. func (pau *PreservationActionUpdate) SetPackageID(i int) *PreservationActionUpdate { pau.mutation.SetPackageID(i) @@ -250,9 +262,15 @@ func (pau *PreservationActionUpdate) sqlSave(ctx context.Context) (n int, err er if value, ok := pau.mutation.StartedAt(); ok { _spec.SetField(preservationaction.FieldStartedAt, field.TypeTime, value) } + if pau.mutation.StartedAtCleared() { + _spec.ClearField(preservationaction.FieldStartedAt, field.TypeTime) + } if value, ok := pau.mutation.CompletedAt(); ok { _spec.SetField(preservationaction.FieldCompletedAt, field.TypeTime, value) } + if pau.mutation.CompletedAtCleared() { + _spec.ClearField(preservationaction.FieldCompletedAt, field.TypeTime) + } if pau.mutation.PackageCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -417,6 +435,12 @@ func (pauo *PreservationActionUpdateOne) SetNillableStartedAt(t *time.Time) *Pre return pauo } +// ClearStartedAt clears the value of the "started_at" field. +func (pauo *PreservationActionUpdateOne) ClearStartedAt() *PreservationActionUpdateOne { + pauo.mutation.ClearStartedAt() + return pauo +} + // SetCompletedAt sets the "completed_at" field. func (pauo *PreservationActionUpdateOne) SetCompletedAt(t time.Time) *PreservationActionUpdateOne { pauo.mutation.SetCompletedAt(t) @@ -431,6 +455,12 @@ func (pauo *PreservationActionUpdateOne) SetNillableCompletedAt(t *time.Time) *P return pauo } +// ClearCompletedAt clears the value of the "completed_at" field. +func (pauo *PreservationActionUpdateOne) ClearCompletedAt() *PreservationActionUpdateOne { + pauo.mutation.ClearCompletedAt() + return pauo +} + // SetPackageID sets the "package_id" field. func (pauo *PreservationActionUpdateOne) SetPackageID(i int) *PreservationActionUpdateOne { pauo.mutation.SetPackageID(i) @@ -597,9 +627,15 @@ func (pauo *PreservationActionUpdateOne) sqlSave(ctx context.Context) (_node *Pr if value, ok := pauo.mutation.StartedAt(); ok { _spec.SetField(preservationaction.FieldStartedAt, field.TypeTime, value) } + if pauo.mutation.StartedAtCleared() { + _spec.ClearField(preservationaction.FieldStartedAt, field.TypeTime) + } if value, ok := pauo.mutation.CompletedAt(); ok { _spec.SetField(preservationaction.FieldCompletedAt, field.TypeTime, value) } + if pauo.mutation.CompletedAtCleared() { + _spec.ClearField(preservationaction.FieldCompletedAt, field.TypeTime) + } if pauo.mutation.PackageCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/internal/persistence/ent/db/preservationtask/where.go b/internal/persistence/ent/db/preservationtask/where.go index 359f2e0d2..e0187409c 100644 --- a/internal/persistence/ent/db/preservationtask/where.go +++ b/internal/persistence/ent/db/preservationtask/where.go @@ -276,6 +276,16 @@ func StartedAtLTE(v time.Time) predicate.PreservationTask { return predicate.PreservationTask(sql.FieldLTE(FieldStartedAt, v)) } +// StartedAtIsNil applies the IsNil predicate on the "started_at" field. +func StartedAtIsNil() predicate.PreservationTask { + return predicate.PreservationTask(sql.FieldIsNull(FieldStartedAt)) +} + +// StartedAtNotNil applies the NotNil predicate on the "started_at" field. +func StartedAtNotNil() predicate.PreservationTask { + return predicate.PreservationTask(sql.FieldNotNull(FieldStartedAt)) +} + // CompletedAtEQ applies the EQ predicate on the "completed_at" field. func CompletedAtEQ(v time.Time) predicate.PreservationTask { return predicate.PreservationTask(sql.FieldEQ(FieldCompletedAt, v)) @@ -316,6 +326,16 @@ func CompletedAtLTE(v time.Time) predicate.PreservationTask { return predicate.PreservationTask(sql.FieldLTE(FieldCompletedAt, v)) } +// CompletedAtIsNil applies the IsNil predicate on the "completed_at" field. +func CompletedAtIsNil() predicate.PreservationTask { + return predicate.PreservationTask(sql.FieldIsNull(FieldCompletedAt)) +} + +// CompletedAtNotNil applies the NotNil predicate on the "completed_at" field. +func CompletedAtNotNil() predicate.PreservationTask { + return predicate.PreservationTask(sql.FieldNotNull(FieldCompletedAt)) +} + // NoteEQ applies the EQ predicate on the "note" field. func NoteEQ(v string) predicate.PreservationTask { return predicate.PreservationTask(sql.FieldEQ(FieldNote, v)) diff --git a/internal/persistence/ent/db/preservationtask_create.go b/internal/persistence/ent/db/preservationtask_create.go index a45a4bc7f..73a166d42 100644 --- a/internal/persistence/ent/db/preservationtask_create.go +++ b/internal/persistence/ent/db/preservationtask_create.go @@ -46,12 +46,28 @@ func (ptc *PreservationTaskCreate) SetStartedAt(t time.Time) *PreservationTaskCr return ptc } +// SetNillableStartedAt sets the "started_at" field if the given value is not nil. +func (ptc *PreservationTaskCreate) SetNillableStartedAt(t *time.Time) *PreservationTaskCreate { + if t != nil { + ptc.SetStartedAt(*t) + } + return ptc +} + // SetCompletedAt sets the "completed_at" field. func (ptc *PreservationTaskCreate) SetCompletedAt(t time.Time) *PreservationTaskCreate { ptc.mutation.SetCompletedAt(t) return ptc } +// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil. +func (ptc *PreservationTaskCreate) SetNillableCompletedAt(t *time.Time) *PreservationTaskCreate { + if t != nil { + ptc.SetCompletedAt(*t) + } + return ptc +} + // SetNote sets the "note" field. func (ptc *PreservationTaskCreate) SetNote(s string) *PreservationTaskCreate { ptc.mutation.SetNote(s) @@ -118,12 +134,6 @@ func (ptc *PreservationTaskCreate) check() error { if _, ok := ptc.mutation.Status(); !ok { return &ValidationError{Name: "status", err: errors.New(`db: missing required field "PreservationTask.status"`)} } - if _, ok := ptc.mutation.StartedAt(); !ok { - return &ValidationError{Name: "started_at", err: errors.New(`db: missing required field "PreservationTask.started_at"`)} - } - if _, ok := ptc.mutation.CompletedAt(); !ok { - return &ValidationError{Name: "completed_at", err: errors.New(`db: missing required field "PreservationTask.completed_at"`)} - } if _, ok := ptc.mutation.Note(); !ok { return &ValidationError{Name: "note", err: errors.New(`db: missing required field "PreservationTask.note"`)} } diff --git a/internal/persistence/ent/db/preservationtask_update.go b/internal/persistence/ent/db/preservationtask_update.go index e1e65e88d..8938b18be 100644 --- a/internal/persistence/ent/db/preservationtask_update.go +++ b/internal/persistence/ent/db/preservationtask_update.go @@ -93,6 +93,12 @@ func (ptu *PreservationTaskUpdate) SetNillableStartedAt(t *time.Time) *Preservat return ptu } +// ClearStartedAt clears the value of the "started_at" field. +func (ptu *PreservationTaskUpdate) ClearStartedAt() *PreservationTaskUpdate { + ptu.mutation.ClearStartedAt() + return ptu +} + // SetCompletedAt sets the "completed_at" field. func (ptu *PreservationTaskUpdate) SetCompletedAt(t time.Time) *PreservationTaskUpdate { ptu.mutation.SetCompletedAt(t) @@ -107,6 +113,12 @@ func (ptu *PreservationTaskUpdate) SetNillableCompletedAt(t *time.Time) *Preserv return ptu } +// ClearCompletedAt clears the value of the "completed_at" field. +func (ptu *PreservationTaskUpdate) ClearCompletedAt() *PreservationTaskUpdate { + ptu.mutation.ClearCompletedAt() + return ptu +} + // SetNote sets the "note" field. func (ptu *PreservationTaskUpdate) SetNote(s string) *PreservationTaskUpdate { ptu.mutation.SetNote(s) @@ -224,9 +236,15 @@ func (ptu *PreservationTaskUpdate) sqlSave(ctx context.Context) (n int, err erro if value, ok := ptu.mutation.StartedAt(); ok { _spec.SetField(preservationtask.FieldStartedAt, field.TypeTime, value) } + if ptu.mutation.StartedAtCleared() { + _spec.ClearField(preservationtask.FieldStartedAt, field.TypeTime) + } if value, ok := ptu.mutation.CompletedAt(); ok { _spec.SetField(preservationtask.FieldCompletedAt, field.TypeTime, value) } + if ptu.mutation.CompletedAtCleared() { + _spec.ClearField(preservationtask.FieldCompletedAt, field.TypeTime) + } if value, ok := ptu.mutation.Note(); ok { _spec.SetField(preservationtask.FieldNote, field.TypeString, value) } @@ -342,6 +360,12 @@ func (ptuo *PreservationTaskUpdateOne) SetNillableStartedAt(t *time.Time) *Prese return ptuo } +// ClearStartedAt clears the value of the "started_at" field. +func (ptuo *PreservationTaskUpdateOne) ClearStartedAt() *PreservationTaskUpdateOne { + ptuo.mutation.ClearStartedAt() + return ptuo +} + // SetCompletedAt sets the "completed_at" field. func (ptuo *PreservationTaskUpdateOne) SetCompletedAt(t time.Time) *PreservationTaskUpdateOne { ptuo.mutation.SetCompletedAt(t) @@ -356,6 +380,12 @@ func (ptuo *PreservationTaskUpdateOne) SetNillableCompletedAt(t *time.Time) *Pre return ptuo } +// ClearCompletedAt clears the value of the "completed_at" field. +func (ptuo *PreservationTaskUpdateOne) ClearCompletedAt() *PreservationTaskUpdateOne { + ptuo.mutation.ClearCompletedAt() + return ptuo +} + // SetNote sets the "note" field. func (ptuo *PreservationTaskUpdateOne) SetNote(s string) *PreservationTaskUpdateOne { ptuo.mutation.SetNote(s) @@ -503,9 +533,15 @@ func (ptuo *PreservationTaskUpdateOne) sqlSave(ctx context.Context) (_node *Pres if value, ok := ptuo.mutation.StartedAt(); ok { _spec.SetField(preservationtask.FieldStartedAt, field.TypeTime, value) } + if ptuo.mutation.StartedAtCleared() { + _spec.ClearField(preservationtask.FieldStartedAt, field.TypeTime) + } if value, ok := ptuo.mutation.CompletedAt(); ok { _spec.SetField(preservationtask.FieldCompletedAt, field.TypeTime, value) } + if ptuo.mutation.CompletedAtCleared() { + _spec.ClearField(preservationtask.FieldCompletedAt, field.TypeTime) + } if value, ok := ptuo.mutation.Note(); ok { _spec.SetField(preservationtask.FieldNote, field.TypeString, value) } diff --git a/internal/persistence/ent/schema/preservation_action.go b/internal/persistence/ent/schema/preservation_action.go index c00a5e0df..6b1c91631 100644 --- a/internal/persistence/ent/schema/preservation_action.go +++ b/internal/persistence/ent/schema/preservation_action.go @@ -29,8 +29,10 @@ func (PreservationAction) Fields() []ent.Field { }), field.Int8("type"), field.Int8("status"), - field.Time("started_at"), - field.Time("completed_at"), + field.Time("started_at"). + Optional(), + field.Time("completed_at"). + Optional(), field.Int("package_id"). Positive(), } diff --git a/internal/persistence/ent/schema/preservation_task.go b/internal/persistence/ent/schema/preservation_task.go index 1000b94a9..cb07be830 100644 --- a/internal/persistence/ent/schema/preservation_task.go +++ b/internal/persistence/ent/schema/preservation_task.go @@ -30,8 +30,10 @@ func (PreservationTask) Fields() []ent.Field { Size: 2048, }), field.Int8("status"), - field.Time("started_at"), - field.Time("completed_at"), + field.Time("started_at"). + Optional(), + field.Time("completed_at"). + Optional(), field.Text("note"), field.Int("preservation_action_id"). Positive(), From a779dc76ea48d6ca98d8b2819db2dc5fe8f68cf1 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 26 Mar 2024 17:31:47 -0700 Subject: [PATCH 3/9] Move entclient package methods Split ent client package methods into a separate `package` file. --- internal/persistence/ent/client/client.go | 126 ------- .../persistence/ent/client/client_test.go | 311 ----------------- internal/persistence/ent/client/package.go | 128 +++++++ .../persistence/ent/client/package_test.go | 326 ++++++++++++++++++ 4 files changed, 454 insertions(+), 437 deletions(-) create mode 100644 internal/persistence/ent/client/package.go create mode 100644 internal/persistence/ent/client/package_test.go diff --git a/internal/persistence/ent/client/client.go b/internal/persistence/ent/client/client.go index a74d70dbc..d6aef4478 100644 --- a/internal/persistence/ent/client/client.go +++ b/internal/persistence/ent/client/client.go @@ -1,13 +1,8 @@ package entclient import ( - "context" - "time" - "github.com/go-logr/logr" - "github.com/google/uuid" - "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/persistence" "github.com/artefactual-sdps/enduro/internal/persistence/ent/db" ) @@ -23,124 +18,3 @@ var _ persistence.Service = (*client)(nil) func New(logger logr.Logger, ent *db.Client) persistence.Service { return &client{logger: logger, ent: ent} } - -// CreatePackage creates and persists a new package using the values from pkg -// then returns the updated package. -// -// The input pkg "ID" and "CreatedAt" values are ignored; the stored package -// "ID" is generated by the persistence implementation and "CreatedAt" is always -// set to the current time. -func (c *client) CreatePackage(ctx context.Context, pkg *datatypes.Package) error { - // Validate required fields. - if pkg.Name == "" { - return newRequiredFieldError("Name") - } - if pkg.WorkflowID == "" { - return newRequiredFieldError("WorkflowID") - } - - if pkg.RunID == "" { - return newRequiredFieldError("RunID") - } - runID, err := uuid.Parse(pkg.RunID) - if err != nil { - return newParseError(err, "RunID") - } - - q := c.ent.Pkg.Create(). - SetName(pkg.Name). - SetWorkflowID(pkg.WorkflowID). - SetRunID(runID). - SetStatus(int8(pkg.Status)) - - // Add optional fields. - if pkg.AIPID.Valid { - q.SetAipID(pkg.AIPID.UUID) - } - if pkg.LocationID.Valid { - q.SetLocationID(pkg.LocationID.UUID) - } - if pkg.StartedAt.Valid { - q.SetStartedAt(pkg.StartedAt.Time) - } - if pkg.CompletedAt.Valid { - q.SetCompletedAt(pkg.CompletedAt.Time) - } - - // Set CreatedAt to the current time - q.SetCreatedAt(time.Now()) - - // Save the package. - p, err := q.Save(ctx) - if err != nil { - return newDBErrorWithDetails(err, "create package") - } - - // Update pkg with DB data, to get generated values (e.g. ID). - *pkg = *convertPkgToPackage(p) - - return nil -} - -// UpdatePackage updates the persisted package identified by id using the -// updater function, then returns the updated package. -// -// The package "ID" and "CreatedAt" field values can not be updated with this -// method. -func (c *client) UpdatePackage( - ctx context.Context, - id uint, - updater persistence.PackageUpdater, -) (*datatypes.Package, error) { - tx, err := c.ent.BeginTx(ctx, nil) - if err != nil { - return nil, newDBError(err) - } - - p, err := tx.Pkg.Get(ctx, int(id)) - if err != nil { - return nil, rollback(tx, newDBError(err)) - } - - up, err := updater(convertPkgToPackage(p)) - if err != nil { - return nil, rollback(tx, newUpdaterError(err)) - } - - runID, err := uuid.Parse(up.RunID) - if err != nil { - return nil, rollback(tx, newParseError(err, "RunID")) - } - - // Set required column values. - q := tx.Pkg.UpdateOneID(int(id)). - SetName(up.Name). - SetWorkflowID(up.WorkflowID). - SetRunID(runID). - SetStatus(int8(up.Status)) - - // Set nullable column values. - if up.AIPID.Valid { - q.SetAipID(up.AIPID.UUID) - } - if up.LocationID.Valid { - q.SetLocationID(up.LocationID.UUID) - } - if up.StartedAt.Valid { - q.SetStartedAt(up.StartedAt.Time) - } - if up.CompletedAt.Valid { - q.SetCompletedAt(up.CompletedAt.Time) - } - - // Save changes. - p, err = q.Save(ctx) - if err != nil { - return nil, rollback(tx, newDBError(err)) - } - if err = tx.Commit(); err != nil { - return nil, rollback(tx, newDBError(err)) - } - - return convertPkgToPackage(p), nil -} diff --git a/internal/persistence/ent/client/client_test.go b/internal/persistence/ent/client/client_test.go index 5de2d9f9f..58b1453b7 100644 --- a/internal/persistence/ent/client/client_test.go +++ b/internal/persistence/ent/client/client_test.go @@ -2,18 +2,14 @@ package entclient_test import ( "context" - "database/sql" "fmt" "testing" - "time" "github.com/go-logr/logr" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" "gotest.tools/v3/assert" - "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/enums" "github.com/artefactual-sdps/enduro/internal/persistence" entclient "github.com/artefactual-sdps/enduro/internal/persistence/ent/client" @@ -57,310 +53,3 @@ func TestNew(t *testing.T) { assert.Equal(t, p.Status, int8(enums.PackageStatusInProgress)) }) } - -func TestCreatePackage(t *testing.T) { - runID := uuid.New() - aipID := uuid.NullUUID{UUID: uuid.New(), Valid: true} - locID := uuid.NullUUID{UUID: uuid.New(), Valid: true} - started := sql.NullTime{Time: time.Now(), Valid: true} - completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} - - type params struct { - pkg *datatypes.Package - } - tests := []struct { - name string - args params - want *datatypes.Package - wantErr string - }{ - { - name: "Saves a new package in the DB", - args: params{ - pkg: &datatypes.Package{ - Name: "Test package 1", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - LocationID: locID, - Status: enums.PackageStatusInProgress, - StartedAt: started, - CompletedAt: completed, - }, - }, - want: &datatypes.Package{ - ID: 1, - Name: "Test package 1", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - LocationID: locID, - Status: enums.PackageStatusInProgress, - CreatedAt: time.Now(), - StartedAt: started, - CompletedAt: completed, - }, - }, - { - name: "Saves a package with missing optional fields", - args: params{ - pkg: &datatypes.Package{ - Name: "Test package 2", - WorkflowID: "workflow-2", - RunID: runID.String(), - Status: enums.PackageStatusInProgress, - }, - }, - want: &datatypes.Package{ - ID: 1, - Name: "Test package 2", - WorkflowID: "workflow-2", - RunID: runID.String(), - Status: enums.PackageStatusInProgress, - CreatedAt: time.Now(), - }, - }, - { - name: "Required field error for missing Name", - args: params{ - pkg: &datatypes.Package{}, - }, - wantErr: "invalid data error: field \"Name\" is required", - }, - { - name: "Required field error for missing WorkflowID", - args: params{ - pkg: &datatypes.Package{ - Name: "Missing WorkflowID", - }, - }, - wantErr: "invalid data error: field \"WorkflowID\" is required", - }, - { - name: "Required field error for missing RunID", - args: params{ - pkg: &datatypes.Package{ - Name: "Missing RunID", - WorkflowID: "workflow-12345", - }, - }, - wantErr: "invalid data error: field \"RunID\" is required", - }, - { - name: "Errors on invalid RunID", - args: params{ - pkg: &datatypes.Package{ - Name: "Invalid package 1", - WorkflowID: "workflow-invalid", - RunID: "Bad UUID", - }, - }, - wantErr: "invalid data error: parse error: field \"RunID\": invalid UUID length: 8", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - _, svc := setUpClient(t, logr.Discard()) - ctx := context.Background() - pkg := *tt.args.pkg // Make a local copy of pkg. - - err := svc.CreatePackage(ctx, &pkg) - if tt.wantErr != "" { - assert.Error(t, err, tt.wantErr) - return - } - assert.NilError(t, err) - - assert.DeepEqual(t, &pkg, tt.want, - cmpopts.EquateApproxTime(time.Millisecond*100), - cmpopts.IgnoreUnexported(db.Pkg{}, db.PkgEdges{}), - ) - }) - } -} - -func TestUpdatePackage(t *testing.T) { - runID := uuid.MustParse("c5f7c35a-d5a6-4e00-b4da-b036ce5b40bc") - runID2 := uuid.MustParse("c04d0191-d7ce-46dd-beff-92d6830082ff") - - aipID := uuid.NullUUID{ - UUID: uuid.MustParse("e2ace0da-8697-453d-9ea1-4c9b62309e54"), - Valid: true, - } - aipID2 := uuid.NullUUID{ - UUID: uuid.MustParse("7d085541-af56-4444-9ce2-d6401ff4c97b"), - Valid: true, - } - - locID := uuid.NullUUID{ - UUID: uuid.MustParse("146182ff-9923-4869-bca1-0bbc0f822025"), - Valid: true, - } - locID2 := uuid.NullUUID{ - UUID: uuid.MustParse("6e30694b-6497-439f-bf99-83af165e02c3"), - Valid: true, - } - - started := sql.NullTime{Time: time.Now(), Valid: true} - started2 := sql.NullTime{ - Time: func() time.Time { - t, _ := time.Parse(time.RFC3339, "1980-01-01T09:30:00Z") - return t - }(), - Valid: true, - } - - completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} - completed2 := sql.NullTime{Time: started2.Time.Add(time.Second), Valid: true} - - type params struct { - pkg *datatypes.Package - updater persistence.PackageUpdater - } - tests := []struct { - name string - args params - want *datatypes.Package - wantErr string - }{ - { - name: "Updates all package columns", - args: params{ - pkg: &datatypes.Package{ - Name: "Test package", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - LocationID: locID, - Status: enums.PackageStatusInProgress, - StartedAt: started, - CompletedAt: completed, - }, - updater: func(p *datatypes.Package) (*datatypes.Package, error) { - p.ID = 100 // No-op, can't update ID. - p.Name = "Updated package" - p.WorkflowID = "workflow-2" - p.RunID = runID2.String() - p.AIPID = aipID2 - p.LocationID = locID2 - p.Status = enums.PackageStatusDone - p.CreatedAt = started2.Time // No-op, can't update CreatedAt. - p.StartedAt = started2 - p.CompletedAt = completed2 - return p, nil - }, - }, - want: &datatypes.Package{ - ID: 1, - Name: "Updated package", - WorkflowID: "workflow-2", - RunID: runID2.String(), - AIPID: aipID2, - LocationID: locID2, - Status: enums.PackageStatusDone, - CreatedAt: time.Now(), - StartedAt: started2, - CompletedAt: completed2, - }, - }, - { - name: "Only updates selected columns", - args: params{ - pkg: &datatypes.Package{ - Name: "Test package", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - Status: enums.PackageStatusInProgress, - StartedAt: started, - }, - updater: func(p *datatypes.Package) (*datatypes.Package, error) { - p.Status = enums.PackageStatusDone - p.CompletedAt = completed - return p, nil - }, - }, - want: &datatypes.Package{ - ID: 1, - Name: "Test package", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - Status: enums.PackageStatusDone, - CreatedAt: time.Now(), - StartedAt: started, - CompletedAt: completed, - }, - }, - { - name: "Errors when package to update is not found", - args: params{ - updater: func(p *datatypes.Package) (*datatypes.Package, error) { - return nil, fmt.Errorf("Bad input") - }, - }, - wantErr: "not found error: db: pkg not found", - }, - { - name: "Errors when the updater errors", - args: params{ - pkg: &datatypes.Package{ - Name: "Test package", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - }, - updater: func(p *datatypes.Package) (*datatypes.Package, error) { - return nil, fmt.Errorf("Bad input") - }, - }, - wantErr: "invalid data error: updater error: Bad input", - }, - { - name: "Errors when updater sets an invalid RunID", - args: params{ - pkg: &datatypes.Package{ - Name: "Test package", - WorkflowID: "workflow-1", - RunID: runID.String(), - AIPID: aipID, - }, - updater: func(p *datatypes.Package) (*datatypes.Package, error) { - p.RunID = "Bad UUID" - return p, nil - }, - }, - wantErr: "invalid data error: parse error: field \"RunID\": invalid UUID length: 8", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - _, svc := setUpClient(t, logr.Discard()) - ctx := context.Background() - - var id uint - if tt.args.pkg != nil { - pkg := *tt.args.pkg // Make a local copy of pkg. - err := svc.CreatePackage(ctx, &pkg) - assert.NilError(t, err) - - id = pkg.ID - } - - pp, err := svc.UpdatePackage(ctx, id, tt.args.updater) - if tt.wantErr != "" { - assert.Error(t, err, tt.wantErr) - return - } - - assert.DeepEqual(t, pp, tt.want, - cmpopts.EquateApproxTime(time.Millisecond*100), - cmpopts.IgnoreUnexported(db.Pkg{}, db.PkgEdges{}), - ) - }) - } -} diff --git a/internal/persistence/ent/client/package.go b/internal/persistence/ent/client/package.go new file mode 100644 index 000000000..d0f90203c --- /dev/null +++ b/internal/persistence/ent/client/package.go @@ -0,0 +1,128 @@ +package entclient + +import ( + "context" + "time" + + "github.com/google/uuid" + + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/persistence" +) + +// CreatePackage creates and persists a new package using the values from pkg +// then returns the updated package. +// +// The input pkg "ID" and "CreatedAt" values are ignored; the stored package +// "ID" is generated by the persistence implementation and "CreatedAt" is always +// set to the current time. +func (c *client) CreatePackage(ctx context.Context, pkg *datatypes.Package) error { + // Validate required fields. + if pkg.Name == "" { + return newRequiredFieldError("Name") + } + if pkg.WorkflowID == "" { + return newRequiredFieldError("WorkflowID") + } + + if pkg.RunID == "" { + return newRequiredFieldError("RunID") + } + runID, err := uuid.Parse(pkg.RunID) + if err != nil { + return newParseError(err, "RunID") + } + + q := c.ent.Pkg.Create(). + SetName(pkg.Name). + SetWorkflowID(pkg.WorkflowID). + SetRunID(runID). + SetStatus(int8(pkg.Status)) + + // Add optional fields. + if pkg.AIPID.Valid { + q.SetAipID(pkg.AIPID.UUID) + } + if pkg.LocationID.Valid { + q.SetLocationID(pkg.LocationID.UUID) + } + if pkg.StartedAt.Valid { + q.SetStartedAt(pkg.StartedAt.Time) + } + if pkg.CompletedAt.Valid { + q.SetCompletedAt(pkg.CompletedAt.Time) + } + + // Set CreatedAt to the current time + q.SetCreatedAt(time.Now()) + + // Save the package. + p, err := q.Save(ctx) + if err != nil { + return newDBErrorWithDetails(err, "create package") + } + + // Update pkg with DB data, to get generated values (e.g. ID). + *pkg = *convertPkgToPackage(p) + + return nil +} + +// UpdatePackage updates the persisted package identified by id using the +// updater function, then returns the updated package. +// +// The package "ID" and "CreatedAt" field values can not be updated with this +// method. +func (c *client) UpdatePackage(ctx context.Context, id uint, updater persistence.PackageUpdater) (*datatypes.Package, error) { + tx, err := c.ent.BeginTx(ctx, nil) + if err != nil { + return nil, newDBError(err) + } + + p, err := tx.Pkg.Get(ctx, int(id)) + if err != nil { + return nil, rollback(tx, newDBError(err)) + } + + up, err := updater(convertPkgToPackage(p)) + if err != nil { + return nil, rollback(tx, newUpdaterError(err)) + } + + runID, err := uuid.Parse(up.RunID) + if err != nil { + return nil, rollback(tx, newParseError(err, "RunID")) + } + + // Set required column values. + q := tx.Pkg.UpdateOneID(int(id)). + SetName(up.Name). + SetWorkflowID(up.WorkflowID). + SetRunID(runID). + SetStatus(int8(up.Status)) + + // Set nullable column values. + if up.AIPID.Valid { + q.SetAipID(up.AIPID.UUID) + } + if up.LocationID.Valid { + q.SetLocationID(up.LocationID.UUID) + } + if up.StartedAt.Valid { + q.SetStartedAt(up.StartedAt.Time) + } + if up.CompletedAt.Valid { + q.SetCompletedAt(up.CompletedAt.Time) + } + + // Save changes. + p, err = q.Save(ctx) + if err != nil { + return nil, rollback(tx, newDBError(err)) + } + if err = tx.Commit(); err != nil { + return nil, rollback(tx, newDBError(err)) + } + + return convertPkgToPackage(p), nil +} diff --git a/internal/persistence/ent/client/package_test.go b/internal/persistence/ent/client/package_test.go new file mode 100644 index 000000000..4d51d4e41 --- /dev/null +++ b/internal/persistence/ent/client/package_test.go @@ -0,0 +1,326 @@ +package entclient_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/go-logr/logr" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/enums" + "github.com/artefactual-sdps/enduro/internal/persistence" + "github.com/artefactual-sdps/enduro/internal/persistence/ent/db" +) + +func TestCreatePackage(t *testing.T) { + runID := uuid.New() + aipID := uuid.NullUUID{UUID: uuid.New(), Valid: true} + locID := uuid.NullUUID{UUID: uuid.New(), Valid: true} + started := sql.NullTime{Time: time.Now(), Valid: true} + completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} + + type params struct { + pkg *datatypes.Package + } + tests := []struct { + name string + args params + want *datatypes.Package + wantErr string + }{ + { + name: "Saves a new package in the DB", + args: params{ + pkg: &datatypes.Package{ + Name: "Test package 1", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusInProgress, + StartedAt: started, + CompletedAt: completed, + }, + }, + want: &datatypes.Package{ + ID: 1, + Name: "Test package 1", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + StartedAt: started, + CompletedAt: completed, + }, + }, + { + name: "Saves a package with missing optional fields", + args: params{ + pkg: &datatypes.Package{ + Name: "Test package 2", + WorkflowID: "workflow-2", + RunID: runID.String(), + Status: enums.PackageStatusInProgress, + }, + }, + want: &datatypes.Package{ + ID: 1, + Name: "Test package 2", + WorkflowID: "workflow-2", + RunID: runID.String(), + Status: enums.PackageStatusInProgress, + CreatedAt: time.Now(), + }, + }, + { + name: "Required field error for missing Name", + args: params{ + pkg: &datatypes.Package{}, + }, + wantErr: "invalid data error: field \"Name\" is required", + }, + { + name: "Required field error for missing WorkflowID", + args: params{ + pkg: &datatypes.Package{ + Name: "Missing WorkflowID", + }, + }, + wantErr: "invalid data error: field \"WorkflowID\" is required", + }, + { + name: "Required field error for missing RunID", + args: params{ + pkg: &datatypes.Package{ + Name: "Missing RunID", + WorkflowID: "workflow-12345", + }, + }, + wantErr: "invalid data error: field \"RunID\" is required", + }, + { + name: "Errors on invalid RunID", + args: params{ + pkg: &datatypes.Package{ + Name: "Invalid package 1", + WorkflowID: "workflow-invalid", + RunID: "Bad UUID", + }, + }, + wantErr: "invalid data error: parse error: field \"RunID\": invalid UUID length: 8", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, svc := setUpClient(t, logr.Discard()) + ctx := context.Background() + pkg := *tt.args.pkg // Make a local copy of pkg. + + err := svc.CreatePackage(ctx, &pkg) + if tt.wantErr != "" { + assert.Error(t, err, tt.wantErr) + return + } + assert.NilError(t, err) + + assert.DeepEqual(t, &pkg, tt.want, + cmpopts.EquateApproxTime(time.Millisecond*100), + cmpopts.IgnoreUnexported(db.Pkg{}, db.PkgEdges{}), + ) + }) + } +} + +func TestUpdatePackage(t *testing.T) { + runID := uuid.MustParse("c5f7c35a-d5a6-4e00-b4da-b036ce5b40bc") + runID2 := uuid.MustParse("c04d0191-d7ce-46dd-beff-92d6830082ff") + + aipID := uuid.NullUUID{ + UUID: uuid.MustParse("e2ace0da-8697-453d-9ea1-4c9b62309e54"), + Valid: true, + } + aipID2 := uuid.NullUUID{ + UUID: uuid.MustParse("7d085541-af56-4444-9ce2-d6401ff4c97b"), + Valid: true, + } + + locID := uuid.NullUUID{ + UUID: uuid.MustParse("146182ff-9923-4869-bca1-0bbc0f822025"), + Valid: true, + } + locID2 := uuid.NullUUID{ + UUID: uuid.MustParse("6e30694b-6497-439f-bf99-83af165e02c3"), + Valid: true, + } + + started := sql.NullTime{Time: time.Now(), Valid: true} + started2 := sql.NullTime{ + Time: func() time.Time { + t, _ := time.Parse(time.RFC3339, "1980-01-01T09:30:00Z") + return t + }(), + Valid: true, + } + + completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} + completed2 := sql.NullTime{Time: started2.Time.Add(time.Second), Valid: true} + + type params struct { + pkg *datatypes.Package + updater persistence.PackageUpdater + } + tests := []struct { + name string + args params + want *datatypes.Package + wantErr string + }{ + { + name: "Updates all package columns", + args: params{ + pkg: &datatypes.Package{ + Name: "Test package", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + LocationID: locID, + Status: enums.PackageStatusInProgress, + StartedAt: started, + CompletedAt: completed, + }, + updater: func(p *datatypes.Package) (*datatypes.Package, error) { + p.ID = 100 // No-op, can't update ID. + p.Name = "Updated package" + p.WorkflowID = "workflow-2" + p.RunID = runID2.String() + p.AIPID = aipID2 + p.LocationID = locID2 + p.Status = enums.PackageStatusDone + p.CreatedAt = started2.Time // No-op, can't update CreatedAt. + p.StartedAt = started2 + p.CompletedAt = completed2 + return p, nil + }, + }, + want: &datatypes.Package{ + ID: 1, + Name: "Updated package", + WorkflowID: "workflow-2", + RunID: runID2.String(), + AIPID: aipID2, + LocationID: locID2, + Status: enums.PackageStatusDone, + CreatedAt: time.Now(), + StartedAt: started2, + CompletedAt: completed2, + }, + }, + { + name: "Only updates selected columns", + args: params{ + pkg: &datatypes.Package{ + Name: "Test package", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + Status: enums.PackageStatusInProgress, + StartedAt: started, + }, + updater: func(p *datatypes.Package) (*datatypes.Package, error) { + p.Status = enums.PackageStatusDone + p.CompletedAt = completed + return p, nil + }, + }, + want: &datatypes.Package{ + ID: 1, + Name: "Test package", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + Status: enums.PackageStatusDone, + CreatedAt: time.Now(), + StartedAt: started, + CompletedAt: completed, + }, + }, + { + name: "Errors when package to update is not found", + args: params{ + updater: func(p *datatypes.Package) (*datatypes.Package, error) { + return nil, fmt.Errorf("Bad input") + }, + }, + wantErr: "not found error: db: pkg not found", + }, + { + name: "Errors when the updater errors", + args: params{ + pkg: &datatypes.Package{ + Name: "Test package", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + }, + updater: func(p *datatypes.Package) (*datatypes.Package, error) { + return nil, fmt.Errorf("Bad input") + }, + }, + wantErr: "invalid data error: updater error: Bad input", + }, + { + name: "Errors when updater sets an invalid RunID", + args: params{ + pkg: &datatypes.Package{ + Name: "Test package", + WorkflowID: "workflow-1", + RunID: runID.String(), + AIPID: aipID, + }, + updater: func(p *datatypes.Package) (*datatypes.Package, error) { + p.RunID = "Bad UUID" + return p, nil + }, + }, + wantErr: "invalid data error: parse error: field \"RunID\": invalid UUID length: 8", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, svc := setUpClient(t, logr.Discard()) + ctx := context.Background() + + var id uint + if tt.args.pkg != nil { + pkg := *tt.args.pkg // Make a local copy of pkg. + err := svc.CreatePackage(ctx, &pkg) + assert.NilError(t, err) + + id = pkg.ID + } + + pp, err := svc.UpdatePackage(ctx, id, tt.args.updater) + if tt.wantErr != "" { + assert.Error(t, err, tt.wantErr) + return + } + + assert.DeepEqual(t, pp, tt.want, + cmpopts.EquateApproxTime(time.Millisecond*100), + cmpopts.IgnoreUnexported(db.Pkg{}, db.PkgEdges{}), + ) + }) + } +} From 227f14ce7887236e26779d7448b7fa40719a8b67 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 26 Mar 2024 17:42:55 -0700 Subject: [PATCH 4/9] Move PreservationTask entities - Move `package_.PreservationTask` struct to `datatypes.PreservationTask` to avoid circular imports - Move PreservationTask enums to the `enums` package to avoid circular imports --- internal/a3m/a3m.go | 14 +- internal/am/job_tracker.go | 18 ++- internal/am/job_tracker_test.go | 14 +- internal/am/poll_ingest_test.go | 4 +- internal/datatypes/preservation_task.go | 20 +++ internal/db/convert_test.go | 3 +- internal/enums/preservation_task.go | 70 +++++++++ internal/package_/fake/mock_package_.go | 12 +- internal/package_/package_.go | 4 +- internal/package_/preservation_action.go | 190 +---------------------- internal/package_/preservation_task.go | 117 ++++++++++++++ internal/workflow/local_activities.go | 18 +-- internal/workflow/processing.go | 14 +- 13 files changed, 262 insertions(+), 236 deletions(-) create mode 100644 internal/datatypes/preservation_task.go create mode 100644 internal/enums/preservation_task.go create mode 100644 internal/package_/preservation_task.go diff --git a/internal/a3m/a3m.go b/internal/a3m/a3m.go index 5819606c0..9e0f41431 100644 --- a/internal/a3m/a3m.go +++ b/internal/a3m/a3m.go @@ -12,6 +12,8 @@ import ( temporalsdk_activity "go.temporal.io/sdk/activity" "google.golang.org/grpc" + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/enums" "github.com/artefactual-sdps/enduro/internal/package_" ) @@ -158,15 +160,15 @@ func (a *CreateAIPActivity) Execute( } func savePreservationTasks(ctx context.Context, jobs []*transferservice.Job, pkgsvc package_.Service, paID uint) error { - jobStatusToPreservationTaskStatus := map[transferservice.Job_Status]package_.PreservationTaskStatus{ - transferservice.Job_STATUS_UNSPECIFIED: package_.TaskStatusUnspecified, - transferservice.Job_STATUS_COMPLETE: package_.TaskStatusDone, - transferservice.Job_STATUS_PROCESSING: package_.TaskStatusInProgress, - transferservice.Job_STATUS_FAILED: package_.TaskStatusError, + jobStatusToPreservationTaskStatus := map[transferservice.Job_Status]enums.PreservationTaskStatus{ + transferservice.Job_STATUS_UNSPECIFIED: enums.PreservationTaskStatusUnspecified, + transferservice.Job_STATUS_COMPLETE: enums.PreservationTaskStatusDone, + transferservice.Job_STATUS_PROCESSING: enums.PreservationTaskStatusInProgress, + transferservice.Job_STATUS_FAILED: enums.PreservationTaskStatusError, } for _, job := range jobs { - pt := package_.PreservationTask{ + pt := datatypes.PreservationTask{ TaskID: job.Id, Name: job.Name, Status: jobStatusToPreservationTaskStatus[job.Status], diff --git a/internal/am/job_tracker.go b/internal/am/job_tracker.go index e50f06359..aa5a7c61e 100644 --- a/internal/am/job_tracker.go +++ b/internal/am/job_tracker.go @@ -7,14 +7,16 @@ import ( "github.com/jonboulle/clockwork" "go.artefactual.dev/amclient" + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/enums" "github.com/artefactual-sdps/enduro/internal/package_" ) -var jobStatusToPreservationTaskStatus = map[amclient.JobStatus]package_.PreservationTaskStatus{ - amclient.JobStatusUnknown: package_.TaskStatusUnspecified, - amclient.JobStatusComplete: package_.TaskStatusDone, - amclient.JobStatusProcessing: package_.TaskStatusInProgress, - amclient.JobStatusFailed: package_.TaskStatusError, +var jobStatusToPreservationTaskStatus = map[amclient.JobStatus]enums.PreservationTaskStatus{ + amclient.JobStatusUnknown: enums.PreservationTaskStatusUnspecified, + amclient.JobStatusComplete: enums.PreservationTaskStatusDone, + amclient.JobStatusProcessing: enums.PreservationTaskStatusInProgress, + amclient.JobStatusFailed: enums.PreservationTaskStatusError, } type JobTracker struct { @@ -116,10 +118,10 @@ func filterSavedJobs(jobs []amclient.Job, saved map[string]struct{}) []amclient. } // ConvertJobToPreservationTask converts an amclient.Job to a -// package_.PreservationTask. -func ConvertJobToPreservationTask(job amclient.Job) package_.PreservationTask { +// datatypes.PreservationTask. +func ConvertJobToPreservationTask(job amclient.Job) datatypes.PreservationTask { st, co := jobTimeRange(job) - return package_.PreservationTask{ + return datatypes.PreservationTask{ TaskID: job.ID, Name: job.Name, Status: jobStatusToPreservationTaskStatus[job.Status], diff --git a/internal/am/job_tracker_test.go b/internal/am/job_tracker_test.go index 4d566eccc..90a9f18bb 100644 --- a/internal/am/job_tracker_test.go +++ b/internal/am/job_tracker_test.go @@ -17,8 +17,8 @@ import ( "gotest.tools/v3/assert" "github.com/artefactual-sdps/enduro/internal/am" + "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/enums" - "github.com/artefactual-sdps/enduro/internal/package_" fake_package "github.com/artefactual-sdps/enduro/internal/package_/fake" ) @@ -181,7 +181,7 @@ func TestConvertJobToPreservationTask(t *testing.T) { type test struct { name string job amclient.Job - want package_.PreservationTask + want datatypes.PreservationTask } for _, tt := range []test{ @@ -212,10 +212,10 @@ func TestConvertJobToPreservationTask(t *testing.T) { }, }, }, - want: package_.PreservationTask{ + want: datatypes.PreservationTask{ TaskID: "f60018ac-da79-4769-9509-c6c41d5efe7e", Name: "Move to processing directory", - Status: package_.PreservationTaskStatus(enums.PackageStatusDone), + Status: enums.PreservationTaskStatus(enums.PackageStatusDone), StartedAt: sql.NullTime{ Time: time.Date(2024, time.January, 18, 1, 27, 49, 0, time.UTC), Valid: true, @@ -253,10 +253,10 @@ func TestConvertJobToPreservationTask(t *testing.T) { }, }, }, - want: package_.PreservationTask{ + want: datatypes.PreservationTask{ TaskID: "c2128d39-2ace-47c5-8cac-39ded8d9c9ef", Name: "Verify SIP compliance", - Status: package_.PreservationTaskStatus(enums.PackageStatusInProgress), + Status: enums.PreservationTaskStatus(enums.PackageStatusInProgress), StartedAt: sql.NullTime{ Time: time.Date(2024, time.January, 18, 1, 27, 49, 0, time.UTC), Valid: true, @@ -267,7 +267,7 @@ func TestConvertJobToPreservationTask(t *testing.T) { { name: "Returns NULL timestamps in the job has no tasks", job: amclient.Job{}, - want: package_.PreservationTask{}, + want: datatypes.PreservationTask{}, }, } { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/am/poll_ingest_test.go b/internal/am/poll_ingest_test.go index aec5eb259..30eb6245d 100644 --- a/internal/am/poll_ingest_test.go +++ b/internal/am/poll_ingest_test.go @@ -18,7 +18,7 @@ import ( "gotest.tools/v3/assert" "github.com/artefactual-sdps/enduro/internal/am" - "github.com/artefactual-sdps/enduro/internal/package_" + "github.com/artefactual-sdps/enduro/internal/datatypes" fake_package "github.com/artefactual-sdps/enduro/internal/package_/fake" ) @@ -154,7 +154,7 @@ func TestPollIngestActivity(t *testing.T) { ) }, pkgRec: func(m *fake_package.MockServiceMockRecorder) { - tasks := make([]*package_.PreservationTask, len(jobs)) + tasks := make([]*datatypes.PreservationTask, len(jobs)) for i, job := range jobs { pt := am.ConvertJobToPreservationTask(job) pt.PreservationActionID = presActionID diff --git a/internal/datatypes/preservation_task.go b/internal/datatypes/preservation_task.go new file mode 100644 index 000000000..600c25c69 --- /dev/null +++ b/internal/datatypes/preservation_task.go @@ -0,0 +1,20 @@ +package datatypes + +import ( + "database/sql" + + "github.com/artefactual-sdps/enduro/internal/enums" +) + +// PreservationTask represents a preservation action task in the +// preservation_task table. +type PreservationTask struct { + ID uint `db:"id"` + TaskID string `db:"task_id"` + Name string `db:"name"` + Status enums.PreservationTaskStatus `db:"status"` + StartedAt sql.NullTime `db:"started_at"` + CompletedAt sql.NullTime `db:"completed_at"` + Note string + PreservationActionID uint `db:"preservation_action_id"` +} diff --git a/internal/db/convert_test.go b/internal/db/convert_test.go index 00b4a267c..2aa5bec1c 100644 --- a/internal/db/convert_test.go +++ b/internal/db/convert_test.go @@ -5,8 +5,9 @@ import ( "testing" "time" - "github.com/artefactual-sdps/enduro/internal/db" "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/db" ) func TestFormatOptionalString(t *testing.T) { diff --git a/internal/enums/preservation_task.go b/internal/enums/preservation_task.go new file mode 100644 index 000000000..4406264e7 --- /dev/null +++ b/internal/enums/preservation_task.go @@ -0,0 +1,70 @@ +package enums + +import ( + "encoding/json" + "strings" +) + +type PreservationTaskStatus uint + +const ( + PreservationTaskStatusUnspecified PreservationTaskStatus = iota + PreservationTaskStatusInProgress + PreservationTaskStatusDone + PreservationTaskStatusError + PreservationTaskStatusQueued + PreservationTaskStatusPending +) + +func NewPreservationTaskStatus(status string) PreservationTaskStatus { + var s PreservationTaskStatus + + switch strings.ToLower(status) { + case "in progress": + s = PreservationTaskStatusInProgress + case "done": + s = PreservationTaskStatusDone + case "error": + s = PreservationTaskStatusError + case "queued": + s = PreservationTaskStatusQueued + case "pending": + s = PreservationTaskStatusPending + default: + s = PreservationTaskStatusUnspecified + } + + return s +} + +func (p PreservationTaskStatus) String() string { + switch p { + case PreservationTaskStatusInProgress: + return "in progress" + case PreservationTaskStatusDone: + return "done" + case PreservationTaskStatusError: + return "error" + case PreservationTaskStatusQueued: + return "queued" + case PreservationTaskStatusPending: + return "pending" + } + + return "unspecified" +} + +func (p PreservationTaskStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(p.String()) +} + +func (p *PreservationTaskStatus) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + *p = NewPreservationTaskStatus(s) + + return nil +} diff --git a/internal/package_/fake/mock_package_.go b/internal/package_/fake/mock_package_.go index c8183d88f..eae61cea8 100644 --- a/internal/package_/fake/mock_package_.go +++ b/internal/package_/fake/mock_package_.go @@ -84,7 +84,7 @@ func (c *MockServiceCompletePreservationActionCall) DoAndReturn(f func(context.C } // CompletePreservationTask mocks base method. -func (m *MockService) CompletePreservationTask(arg0 context.Context, arg1 uint, arg2 package_0.PreservationTaskStatus, arg3 time.Time, arg4 *string) error { +func (m *MockService) CompletePreservationTask(arg0 context.Context, arg1 uint, arg2 enums.PreservationTaskStatus, arg3 time.Time, arg4 *string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CompletePreservationTask", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) @@ -110,13 +110,13 @@ func (c *MockServiceCompletePreservationTaskCall) Return(arg0 error) *MockServic } // Do rewrite *gomock.Call.Do -func (c *MockServiceCompletePreservationTaskCall) Do(f func(context.Context, uint, package_0.PreservationTaskStatus, time.Time, *string) error) *MockServiceCompletePreservationTaskCall { +func (c *MockServiceCompletePreservationTaskCall) Do(f func(context.Context, uint, enums.PreservationTaskStatus, time.Time, *string) error) *MockServiceCompletePreservationTaskCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockServiceCompletePreservationTaskCall) DoAndReturn(f func(context.Context, uint, package_0.PreservationTaskStatus, time.Time, *string) error) *MockServiceCompletePreservationTaskCall { +func (c *MockServiceCompletePreservationTaskCall) DoAndReturn(f func(context.Context, uint, enums.PreservationTaskStatus, time.Time, *string) error) *MockServiceCompletePreservationTaskCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -198,7 +198,7 @@ func (c *MockServiceCreatePreservationActionCall) DoAndReturn(f func(context.Con } // CreatePreservationTask mocks base method. -func (m *MockService) CreatePreservationTask(arg0 context.Context, arg1 *package_0.PreservationTask) error { +func (m *MockService) CreatePreservationTask(arg0 context.Context, arg1 *datatypes.PreservationTask) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreatePreservationTask", arg0, arg1) ret0, _ := ret[0].(error) @@ -224,13 +224,13 @@ func (c *MockServiceCreatePreservationTaskCall) Return(arg0 error) *MockServiceC } // Do rewrite *gomock.Call.Do -func (c *MockServiceCreatePreservationTaskCall) Do(f func(context.Context, *package_0.PreservationTask) error) *MockServiceCreatePreservationTaskCall { +func (c *MockServiceCreatePreservationTaskCall) Do(f func(context.Context, *datatypes.PreservationTask) error) *MockServiceCreatePreservationTaskCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockServiceCreatePreservationTaskCall) DoAndReturn(f func(context.Context, *package_0.PreservationTask) error) *MockServiceCreatePreservationTaskCall { +func (c *MockServiceCreatePreservationTaskCall) DoAndReturn(f func(context.Context, *datatypes.PreservationTask) error) *MockServiceCreatePreservationTaskCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/package_/package_.go b/internal/package_/package_.go index aeb370395..93df88498 100644 --- a/internal/package_/package_.go +++ b/internal/package_/package_.go @@ -42,11 +42,11 @@ type Service interface { status PreservationActionStatus, completedAt time.Time, ) error - CreatePreservationTask(ctx context.Context, pt *PreservationTask) error + CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error CompletePreservationTask( ctx context.Context, ID uint, - status PreservationTaskStatus, + status enums.PreservationTaskStatus, completedAt time.Time, note *string, ) error diff --git a/internal/package_/preservation_action.go b/internal/package_/preservation_action.go index 21339512b..cf9f1e84b 100644 --- a/internal/package_/preservation_action.go +++ b/internal/package_/preservation_action.go @@ -11,6 +11,7 @@ import ( "go.artefactual.dev/tools/ref" goapackage "github.com/artefactual-sdps/enduro/internal/api/gen/package_" + "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/db" "github.com/artefactual-sdps/enduro/internal/event" ) @@ -144,82 +145,6 @@ type PreservationAction struct { PackageID uint `db:"package_id"` } -type PreservationTaskStatus uint - -const ( - TaskStatusUnspecified PreservationTaskStatus = iota - TaskStatusInProgress - TaskStatusDone - TaskStatusError - TaskStatusQueued - TaskStatusPending -) - -func NewPreservationTaskStatus(status string) PreservationTaskStatus { - var s PreservationTaskStatus - - switch strings.ToLower(status) { - case "in progress": - s = TaskStatusInProgress - case "done": - s = TaskStatusDone - case "error": - s = TaskStatusError - case "queued": - s = TaskStatusQueued - case "pending": - s = TaskStatusPending - default: - s = TaskStatusUnspecified - } - - return s -} - -func (p PreservationTaskStatus) String() string { - switch p { - case TaskStatusInProgress: - return "in progress" - case TaskStatusDone: - return "done" - case TaskStatusError: - return "error" - case TaskStatusQueued: - return "queued" - case TaskStatusPending: - return "pending" - } - - return "unspecified" -} - -func (p PreservationTaskStatus) MarshalJSON() ([]byte, error) { - return json.Marshal(p.String()) -} - -func (p *PreservationTaskStatus) UnmarshalJSON(b []byte) error { - var s string - if err := json.Unmarshal(b, &s); err != nil { - return err - } - - *p = NewPreservationTaskStatus(s) - - return nil -} - -// PreservationTask represents a preservation action task in the preservation_task table. -type PreservationTask struct { - ID uint `db:"id"` - TaskID string `db:"task_id"` - Name string `db:"name"` - Status PreservationTaskStatus `db:"status"` - StartedAt sql.NullTime `db:"started_at"` - CompletedAt sql.NullTime `db:"completed_at"` - Note string - PreservationActionID uint `db:"preservation_action_id"` -} - func (w *goaWrapper) PreservationActions( ctx context.Context, payload *goapackage.PreservationActionsPayload, @@ -264,7 +189,7 @@ func (w *goaWrapper) PreservationActions( preservation_tasks := []*goapackage.EnduroPackagePreservationTask{} for ptRows.Next() { - pt := PreservationTask{} + pt := datatypes.PreservationTask{} if err := ptRows.StructScan(&pt); err != nil { return nil, fmt.Errorf("error scanning database result: %w", err) } @@ -381,78 +306,6 @@ func (svc *packageImpl) CompletePreservationAction( return nil } -func (svc *packageImpl) CreatePreservationTask(ctx context.Context, pt *PreservationTask) error { - startedAt := &pt.StartedAt.Time - completedAt := &pt.CompletedAt.Time - if pt.StartedAt.Time.IsZero() { - startedAt = nil - } - if pt.CompletedAt.Time.IsZero() { - completedAt = nil - } - - query := `INSERT INTO preservation_task (task_id, name, status, started_at, completed_at, note, preservation_action_id) VALUES (?, ?, ?, ?, ?, ?, ?)` - args := []interface{}{ - pt.TaskID, - pt.Name, - pt.Status, - startedAt, - completedAt, - pt.Note, - pt.PreservationActionID, - } - - res, err := svc.db.ExecContext(ctx, query, args...) - if err != nil { - return fmt.Errorf("error inserting preservation task: %w", err) - } - - var id int64 - if id, err = res.LastInsertId(); err != nil { - return fmt.Errorf("error retrieving insert ID: %w", err) - } - - pt.ID = uint(id) - - if item, err := svc.readPreservationTask(ctx, pt.ID); err == nil { - ev := &goapackage.PreservationTaskCreatedEvent{ID: pt.ID, Item: item} - event.PublishEvent(ctx, svc.evsvc, ev) - } - - return nil -} - -func (svc *packageImpl) CompletePreservationTask( - ctx context.Context, - ID uint, - status PreservationTaskStatus, - completedAt time.Time, - note *string, -) error { - var query string - args := []interface{}{} - - if note != nil { - query = `UPDATE preservation_task SET note = ?, status = ?, completed_at = ? WHERE id = ?` - args = append(args, note, status, completedAt, ID) - } else { - query = `UPDATE preservation_task SET status = ?, completed_at = ? WHERE id = ?` - args = append(args, status, completedAt, ID) - } - - _, err := svc.db.ExecContext(ctx, query, args...) - if err != nil { - return fmt.Errorf("error updating preservation task: %w", err) - } - - if item, err := svc.readPreservationTask(ctx, ID); err == nil { - ev := &goapackage.PreservationTaskUpdatedEvent{ID: ID, Item: item} - event.PublishEvent(ctx, svc.evsvc, ev) - } - - return nil -} - func (svc *packageImpl) readPreservationAction( ctx context.Context, ID uint, @@ -489,42 +342,3 @@ func (svc *packageImpl) readPreservationAction( return &item, nil } - -func (svc *packageImpl) readPreservationTask( - ctx context.Context, - ID uint, -) (*goapackage.EnduroPackagePreservationTask, error) { - query := ` - SELECT - preservation_task.id, - preservation_task.task_id, - preservation_task.name, - preservation_task.status, - CONVERT_TZ(preservation_task.started_at, @@session.time_zone, '+00:00') AS started_at, - CONVERT_TZ(preservation_task.completed_at, @@session.time_zone, '+00:00') AS completed_at, - preservation_task.note, - preservation_task.preservation_action_id - FROM preservation_task - LEFT JOIN preservation_action ON (preservation_task.preservation_action_id = preservation_action.id) - WHERE preservation_task.id = ? - ` - args := []interface{}{ID} - dbItem := PreservationTask{} - - if err := svc.db.GetContext(ctx, &dbItem, query, args...); err != nil { - return nil, err - } - - item := goapackage.EnduroPackagePreservationTask{ - ID: dbItem.ID, - TaskID: dbItem.TaskID, - Name: dbItem.Name, - Status: dbItem.Status.String(), - StartedAt: ref.DerefZero(db.FormatOptionalTime(dbItem.StartedAt)), - CompletedAt: db.FormatOptionalTime(dbItem.CompletedAt), - Note: ref.New(dbItem.Note), - PreservationActionID: ref.New(dbItem.PreservationActionID), - } - - return &item, nil -} diff --git a/internal/package_/preservation_task.go b/internal/package_/preservation_task.go new file mode 100644 index 000000000..f34614d1f --- /dev/null +++ b/internal/package_/preservation_task.go @@ -0,0 +1,117 @@ +package package_ + +import ( + "context" + "fmt" + "time" + + "go.artefactual.dev/tools/ref" + + goapackage "github.com/artefactual-sdps/enduro/internal/api/gen/package_" + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/db" + "github.com/artefactual-sdps/enduro/internal/enums" + "github.com/artefactual-sdps/enduro/internal/event" +) + +func (svc *packageImpl) CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error { + startedAt := &pt.StartedAt.Time + completedAt := &pt.CompletedAt.Time + if pt.StartedAt.Time.IsZero() { + startedAt = nil + } + if pt.CompletedAt.Time.IsZero() { + completedAt = nil + } + + query := `INSERT INTO preservation_task (task_id, name, status, started_at, completed_at, note, preservation_action_id) VALUES (?, ?, ?, ?, ?, ?, ?)` + args := []interface{}{ + pt.TaskID, + pt.Name, + pt.Status, + startedAt, + completedAt, + pt.Note, + pt.PreservationActionID, + } + + res, err := svc.db.ExecContext(ctx, query, args...) + if err != nil { + return fmt.Errorf("error inserting preservation task: %w", err) + } + + var id int64 + if id, err = res.LastInsertId(); err != nil { + return fmt.Errorf("error retrieving insert ID: %w", err) + } + + pt.ID = uint(id) + + if item, err := svc.readPreservationTask(ctx, pt.ID); err == nil { + ev := &goapackage.PreservationTaskCreatedEvent{ID: pt.ID, Item: item} + event.PublishEvent(ctx, svc.evsvc, ev) + } + + return nil +} + +func (svc *packageImpl) CompletePreservationTask(ctx context.Context, ID uint, status enums.PreservationTaskStatus, completedAt time.Time, note *string) error { + var query string + args := []interface{}{} + + if note != nil { + query = `UPDATE preservation_task SET note = ?, status = ?, completed_at = ? WHERE id = ?` + args = append(args, note, status, completedAt, ID) + } else { + query = `UPDATE preservation_task SET status = ?, completed_at = ? WHERE id = ?` + args = append(args, status, completedAt, ID) + } + + _, err := svc.db.ExecContext(ctx, query, args...) + if err != nil { + return fmt.Errorf("error updating preservation task: %w", err) + } + + if item, err := svc.readPreservationTask(ctx, ID); err == nil { + ev := &goapackage.PreservationTaskUpdatedEvent{ID: ID, Item: item} + event.PublishEvent(ctx, svc.evsvc, ev) + } + + return nil +} + +func (svc *packageImpl) readPreservationTask(ctx context.Context, ID uint) (*goapackage.EnduroPackagePreservationTask, error) { + query := ` + SELECT + preservation_task.id, + preservation_task.task_id, + preservation_task.name, + preservation_task.status, + CONVERT_TZ(preservation_task.started_at, @@session.time_zone, '+00:00') AS started_at, + CONVERT_TZ(preservation_task.completed_at, @@session.time_zone, '+00:00') AS completed_at, + preservation_task.note, + preservation_task.preservation_action_id + FROM preservation_task + LEFT JOIN preservation_action ON (preservation_task.preservation_action_id = preservation_action.id) + WHERE preservation_task.id = ? + ` + args := []interface{}{ID} + dbItem := datatypes.PreservationTask{} + + if err := svc.db.GetContext(ctx, &dbItem, query, args...); err != nil { + return nil, err + } + + item := goapackage.EnduroPackagePreservationTask{ + ID: dbItem.ID, + TaskID: dbItem.TaskID, + Name: dbItem.Name, + Status: dbItem.Status.String(), + StartedAt: ref.DerefZero(db.FormatOptionalTime(dbItem.StartedAt)), + CompletedAt: db.FormatOptionalTime(dbItem.CompletedAt), + Note: ref.New(dbItem.Note), + PreservationActionID: ref.New(dbItem.PreservationActionID), + } + + return &item, nil +} diff --git a/internal/workflow/local_activities.go b/internal/workflow/local_activities.go index 3ca30c261..3298f5888 100644 --- a/internal/workflow/local_activities.go +++ b/internal/workflow/local_activities.go @@ -140,14 +140,14 @@ func saveLocationMovePreservationActionLocalActivity( return &saveLocationMovePreservationActionLocalActivityResult{}, err } - actionStatusToTaskStatus := map[package_.PreservationActionStatus]package_.PreservationTaskStatus{ - package_.ActionStatusUnspecified: package_.TaskStatusUnspecified, - package_.ActionStatusDone: package_.TaskStatusDone, - package_.ActionStatusInProgress: package_.TaskStatusInProgress, - package_.ActionStatusError: package_.TaskStatusError, + actionStatusToTaskStatus := map[package_.PreservationActionStatus]enums.PreservationTaskStatus{ + package_.ActionStatusUnspecified: enums.PreservationTaskStatusUnspecified, + package_.ActionStatusDone: enums.PreservationTaskStatusDone, + package_.ActionStatusInProgress: enums.PreservationTaskStatusInProgress, + package_.ActionStatusError: enums.PreservationTaskStatusError, } - pt := package_.PreservationTask{ + pt := datatypes.PreservationTask{ TaskID: uuid.NewString(), Name: "Move AIP", Status: actionStatusToTaskStatus[params.Status], @@ -225,7 +225,7 @@ func completePreservationActionLocalActivity( type createPreservationTaskLocalActivityParams struct { TaskID string Name string - Status package_.PreservationTaskStatus + Status enums.PreservationTaskStatus StartedAt time.Time CompletedAt time.Time Note string @@ -237,7 +237,7 @@ func createPreservationTaskLocalActivity( pkgsvc package_.Service, params *createPreservationTaskLocalActivityParams, ) (uint, error) { - pt := package_.PreservationTask{ + pt := datatypes.PreservationTask{ TaskID: params.TaskID, Name: params.Name, Status: params.Status, @@ -256,7 +256,7 @@ func createPreservationTaskLocalActivity( type completePreservationTaskLocalActivityParams struct { ID uint - Status package_.PreservationTaskStatus + Status enums.PreservationTaskStatus CompletedAt time.Time Note *string } diff --git a/internal/workflow/processing.go b/internal/workflow/processing.go index 13d6e8268..70de82ef5 100644 --- a/internal/workflow/processing.go +++ b/internal/workflow/processing.go @@ -418,7 +418,7 @@ func (w *ProcessingWorkflow) SessionHandler( err := temporalsdk_workflow.ExecuteLocalActivity(ctx, createPreservationTaskLocalActivity, w.pkgsvc, &createPreservationTaskLocalActivityParams{ TaskID: uuid.NewString(), Name: "Move AIP", - Status: package_.TaskStatusInProgress, + Status: enums.PreservationTaskStatusInProgress, StartedAt: temporalsdk_workflow.Now(sessCtx).UTC(), Note: "Moving to review bucket", PreservationActionID: tinfo.PreservationActionID, @@ -455,7 +455,7 @@ func (w *ProcessingWorkflow) SessionHandler( ctx := withLocalActivityOpts(sessCtx) err := temporalsdk_workflow.ExecuteLocalActivity(ctx, completePreservationTaskLocalActivity, w.pkgsvc, &completePreservationTaskLocalActivityParams{ ID: uploadPreservationTaskID, - Status: package_.TaskStatusDone, + Status: enums.PreservationTaskStatusDone, CompletedAt: temporalsdk_workflow.Now(sessCtx).UTC(), Note: ref.New("Moved to review bucket"), }). @@ -500,7 +500,7 @@ func (w *ProcessingWorkflow) SessionHandler( err := temporalsdk_workflow.ExecuteLocalActivity(ctx, createPreservationTaskLocalActivity, w.pkgsvc, &createPreservationTaskLocalActivityParams{ TaskID: uuid.NewString(), Name: "Review AIP", - Status: package_.TaskStatusPending, + Status: enums.PreservationTaskStatusPending, StartedAt: temporalsdk_workflow.Now(sessCtx).UTC(), Note: "Awaiting user decision", PreservationActionID: tinfo.PreservationActionID, @@ -539,7 +539,7 @@ func (w *ProcessingWorkflow) SessionHandler( ctx := withLocalActivityOpts(sessCtx) err := temporalsdk_workflow.ExecuteLocalActivity(ctx, completePreservationTaskLocalActivity, w.pkgsvc, &completePreservationTaskLocalActivityParams{ ID: reviewPreservationTaskID, - Status: package_.TaskStatusDone, + Status: enums.PreservationTaskStatusDone, CompletedAt: reviewCompletedAt, Note: ref.New("Reviewed and accepted"), }). @@ -558,7 +558,7 @@ func (w *ProcessingWorkflow) SessionHandler( err := temporalsdk_workflow.ExecuteLocalActivity(ctx, createPreservationTaskLocalActivity, w.pkgsvc, &createPreservationTaskLocalActivityParams{ TaskID: uuid.NewString(), Name: "Move AIP", - Status: package_.TaskStatusInProgress, + Status: enums.PreservationTaskStatusInProgress, StartedAt: temporalsdk_workflow.Now(sessCtx).UTC(), Note: "Moving to permanent storage", PreservationActionID: tinfo.PreservationActionID, @@ -599,7 +599,7 @@ func (w *ProcessingWorkflow) SessionHandler( ctx := withLocalActivityOpts(sessCtx) err := temporalsdk_workflow.ExecuteLocalActivity(ctx, completePreservationTaskLocalActivity, w.pkgsvc, &completePreservationTaskLocalActivityParams{ ID: movePreservationTaskID, - Status: package_.TaskStatusDone, + Status: enums.PreservationTaskStatusDone, CompletedAt: temporalsdk_workflow.Now(sessCtx).UTC(), Note: ref.New(fmt.Sprintf("Moved to location %s", *reviewResult.LocationID)), }). @@ -624,7 +624,7 @@ func (w *ProcessingWorkflow) SessionHandler( ctx := withLocalActivityOpts(sessCtx) err := temporalsdk_workflow.ExecuteLocalActivity(ctx, completePreservationTaskLocalActivity, w.pkgsvc, &completePreservationTaskLocalActivityParams{ ID: reviewPreservationTaskID, - Status: package_.TaskStatusDone, + Status: enums.PreservationTaskStatusDone, CompletedAt: reviewCompletedAt, Note: ref.New("Reviewed and rejected"), }).Get(ctx, nil) From 91f3ae610d83ca29b56cc1b45bc28bd1afff27f6 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 26 Mar 2024 18:05:22 -0700 Subject: [PATCH 5/9] Move PreservationAction entities - Move PreservationAction struct to `datatypes` package - Move PreservationAction enums to `enums` package --- internal/datatypes/preservation_action.go | 18 +++ internal/enums/preservation_action.go | 124 +++++++++++++++++++ internal/package_/fake/mock_package_.go | 19 ++- internal/package_/package_.go | 11 +- internal/package_/preservation_action.go | 143 +--------------------- internal/workflow/local_activities.go | 24 ++-- internal/workflow/move.go | 12 +- internal/workflow/processing.go | 20 +-- 8 files changed, 188 insertions(+), 183 deletions(-) create mode 100644 internal/datatypes/preservation_action.go create mode 100644 internal/enums/preservation_action.go diff --git a/internal/datatypes/preservation_action.go b/internal/datatypes/preservation_action.go new file mode 100644 index 000000000..d090d1ed8 --- /dev/null +++ b/internal/datatypes/preservation_action.go @@ -0,0 +1,18 @@ +package datatypes + +import ( + "database/sql" + + "github.com/artefactual-sdps/enduro/internal/enums" +) + +// PreservationAction represents a preservation action in the preservation_action table. +type PreservationAction struct { + ID uint `db:"id"` + WorkflowID string `db:"workflow_id"` + Type enums.PreservationActionType `db:"type"` + Status enums.PreservationActionStatus `db:"status"` + StartedAt sql.NullTime `db:"started_at"` + CompletedAt sql.NullTime `db:"completed_at"` + PackageID uint `db:"package_id"` +} diff --git a/internal/enums/preservation_action.go b/internal/enums/preservation_action.go new file mode 100644 index 000000000..9166c135b --- /dev/null +++ b/internal/enums/preservation_action.go @@ -0,0 +1,124 @@ +package enums + +import ( + "encoding/json" + "strings" +) + +type PreservationActionType uint + +const ( + PreservationActionTypeUnspecified PreservationActionType = iota + PreservationActionTypeCreateAIP + PreservationActionTypeCreateAndReviewAIP + PreservationActionTypeMovePackage +) + +func NewPreservationActionType(status string) PreservationActionType { + var s PreservationActionType + + switch strings.ToLower(status) { + case "create-aip": + s = PreservationActionTypeCreateAIP + case "create-and-review-aip": + s = PreservationActionTypeCreateAndReviewAIP + case "move-package": + s = PreservationActionTypeMovePackage + default: + s = PreservationActionTypeUnspecified + } + + return s +} + +func (p PreservationActionType) String() string { + switch p { + case PreservationActionTypeCreateAIP: + return "create-aip" + case PreservationActionTypeCreateAndReviewAIP: + return "create-and-review-aip" + case PreservationActionTypeMovePackage: + return "move-package" + } + + return "unspecified" +} + +func (p PreservationActionType) MarshalJSON() ([]byte, error) { + return json.Marshal(p.String()) +} + +func (p *PreservationActionType) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + *p = NewPreservationActionType(s) + + return nil +} + +type PreservationActionStatus uint + +const ( + PreservationActionStatusUnspecified PreservationActionStatus = iota + PreservationActionStatusInProgress + PreservationActionStatusDone + PreservationActionStatusError + PreservationActionStatusQueued + PreservationActionStatusPending +) + +func NewPreservationActionStatus(status string) PreservationActionStatus { + var s PreservationActionStatus + + switch strings.ToLower(status) { + case "in progress": + s = PreservationActionStatusInProgress + case "done": + s = PreservationActionStatusDone + case "error": + s = PreservationActionStatusError + case "queued": + s = PreservationActionStatusQueued + case "pending": + s = PreservationActionStatusPending + default: + s = PreservationActionStatusUnspecified + } + + return s +} + +func (p PreservationActionStatus) String() string { + switch p { + case PreservationActionStatusInProgress: + return "in progress" + case PreservationActionStatusDone: + return "done" + case PreservationActionStatusError: + return "error" + case PreservationActionStatusQueued: + return "queued" + case PreservationActionStatusPending: + return "pending" + } + + return "unspecified" +} + +func (p PreservationActionStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(p.String()) +} + +func (p *PreservationActionStatus) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + *p = NewPreservationActionStatus(s) + + return nil +} diff --git a/internal/package_/fake/mock_package_.go b/internal/package_/fake/mock_package_.go index eae61cea8..ce225eb24 100644 --- a/internal/package_/fake/mock_package_.go +++ b/internal/package_/fake/mock_package_.go @@ -17,7 +17,6 @@ import ( package_ "github.com/artefactual-sdps/enduro/internal/api/gen/package_" datatypes "github.com/artefactual-sdps/enduro/internal/datatypes" enums "github.com/artefactual-sdps/enduro/internal/enums" - package_0 "github.com/artefactual-sdps/enduro/internal/package_" uuid "github.com/google/uuid" gomock "go.uber.org/mock/gomock" ) @@ -46,7 +45,7 @@ func (m *MockService) EXPECT() *MockServiceMockRecorder { } // CompletePreservationAction mocks base method. -func (m *MockService) CompletePreservationAction(arg0 context.Context, arg1 uint, arg2 package_0.PreservationActionStatus, arg3 time.Time) error { +func (m *MockService) CompletePreservationAction(arg0 context.Context, arg1 uint, arg2 enums.PreservationActionStatus, arg3 time.Time) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CompletePreservationAction", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) @@ -72,13 +71,13 @@ func (c *MockServiceCompletePreservationActionCall) Return(arg0 error) *MockServ } // Do rewrite *gomock.Call.Do -func (c *MockServiceCompletePreservationActionCall) Do(f func(context.Context, uint, package_0.PreservationActionStatus, time.Time) error) *MockServiceCompletePreservationActionCall { +func (c *MockServiceCompletePreservationActionCall) Do(f func(context.Context, uint, enums.PreservationActionStatus, time.Time) error) *MockServiceCompletePreservationActionCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockServiceCompletePreservationActionCall) DoAndReturn(f func(context.Context, uint, package_0.PreservationActionStatus, time.Time) error) *MockServiceCompletePreservationActionCall { +func (c *MockServiceCompletePreservationActionCall) DoAndReturn(f func(context.Context, uint, enums.PreservationActionStatus, time.Time) error) *MockServiceCompletePreservationActionCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -160,7 +159,7 @@ func (c *MockServiceCreateCall) DoAndReturn(f func(context.Context, *datatypes.P } // CreatePreservationAction mocks base method. -func (m *MockService) CreatePreservationAction(arg0 context.Context, arg1 *package_0.PreservationAction) error { +func (m *MockService) CreatePreservationAction(arg0 context.Context, arg1 *datatypes.PreservationAction) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreatePreservationAction", arg0, arg1) ret0, _ := ret[0].(error) @@ -186,13 +185,13 @@ func (c *MockServiceCreatePreservationActionCall) Return(arg0 error) *MockServic } // Do rewrite *gomock.Call.Do -func (c *MockServiceCreatePreservationActionCall) Do(f func(context.Context, *package_0.PreservationAction) error) *MockServiceCreatePreservationActionCall { +func (c *MockServiceCreatePreservationActionCall) Do(f func(context.Context, *datatypes.PreservationAction) error) *MockServiceCreatePreservationActionCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockServiceCreatePreservationActionCall) DoAndReturn(f func(context.Context, *package_0.PreservationAction) error) *MockServiceCreatePreservationActionCall { +func (c *MockServiceCreatePreservationActionCall) DoAndReturn(f func(context.Context, *datatypes.PreservationAction) error) *MockServiceCreatePreservationActionCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -312,7 +311,7 @@ func (c *MockServiceSetLocationIDCall) DoAndReturn(f func(context.Context, uint, } // SetPreservationActionStatus mocks base method. -func (m *MockService) SetPreservationActionStatus(arg0 context.Context, arg1 uint, arg2 package_0.PreservationActionStatus) error { +func (m *MockService) SetPreservationActionStatus(arg0 context.Context, arg1 uint, arg2 enums.PreservationActionStatus) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetPreservationActionStatus", arg0, arg1, arg2) ret0, _ := ret[0].(error) @@ -338,13 +337,13 @@ func (c *MockServiceSetPreservationActionStatusCall) Return(arg0 error) *MockSer } // Do rewrite *gomock.Call.Do -func (c *MockServiceSetPreservationActionStatusCall) Do(f func(context.Context, uint, package_0.PreservationActionStatus) error) *MockServiceSetPreservationActionStatusCall { +func (c *MockServiceSetPreservationActionStatusCall) Do(f func(context.Context, uint, enums.PreservationActionStatus) error) *MockServiceSetPreservationActionStatusCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockServiceSetPreservationActionStatusCall) DoAndReturn(f func(context.Context, uint, package_0.PreservationActionStatus) error) *MockServiceSetPreservationActionStatusCall { +func (c *MockServiceSetPreservationActionStatusCall) DoAndReturn(f func(context.Context, uint, enums.PreservationActionStatus) error) *MockServiceSetPreservationActionStatusCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/package_/package_.go b/internal/package_/package_.go index 93df88498..bfdec1404 100644 --- a/internal/package_/package_.go +++ b/internal/package_/package_.go @@ -34,14 +34,9 @@ type Service interface { SetStatusInProgress(ctx context.Context, ID uint, startedAt time.Time) error SetStatusPending(ctx context.Context, ID uint) error SetLocationID(ctx context.Context, ID uint, locationID uuid.UUID) error - CreatePreservationAction(ctx context.Context, pa *PreservationAction) error - SetPreservationActionStatus(ctx context.Context, ID uint, status PreservationActionStatus) error - CompletePreservationAction( - ctx context.Context, - ID uint, - status PreservationActionStatus, - completedAt time.Time, - ) error + CreatePreservationAction(ctx context.Context, pa *datatypes.PreservationAction) error + SetPreservationActionStatus(ctx context.Context, ID uint, status enums.PreservationActionStatus) error + CompletePreservationAction(ctx context.Context, ID uint, status enums.PreservationActionStatus, completedAt time.Time) error CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error CompletePreservationTask( ctx context.Context, diff --git a/internal/package_/preservation_action.go b/internal/package_/preservation_action.go index cf9f1e84b..772109ce8 100644 --- a/internal/package_/preservation_action.go +++ b/internal/package_/preservation_action.go @@ -2,10 +2,7 @@ package package_ import ( "context" - "database/sql" - "encoding/json" "fmt" - "strings" "time" "go.artefactual.dev/tools/ref" @@ -13,138 +10,10 @@ import ( goapackage "github.com/artefactual-sdps/enduro/internal/api/gen/package_" "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/db" + "github.com/artefactual-sdps/enduro/internal/enums" "github.com/artefactual-sdps/enduro/internal/event" ) -type PreservationActionType uint - -const ( - ActionTypeUnspecified PreservationActionType = iota - ActionTypeCreateAIP - ActionTypeCreateAndReviewAIP - ActionTypeMovePackage -) - -func NewPreservationActionType(status string) PreservationActionType { - var s PreservationActionType - - switch strings.ToLower(status) { - case "create-aip": - s = ActionTypeCreateAIP - case "create-and-review-aip": - s = ActionTypeCreateAndReviewAIP - case "move-package": - s = ActionTypeMovePackage - default: - s = ActionTypeUnspecified - } - - return s -} - -func (p PreservationActionType) String() string { - switch p { - case ActionTypeCreateAIP: - return "create-aip" - case ActionTypeCreateAndReviewAIP: - return "create-and-review-aip" - case ActionTypeMovePackage: - return "move-package" - } - - return "unspecified" -} - -func (p PreservationActionType) MarshalJSON() ([]byte, error) { - return json.Marshal(p.String()) -} - -func (p *PreservationActionType) UnmarshalJSON(b []byte) error { - var s string - if err := json.Unmarshal(b, &s); err != nil { - return err - } - - *p = NewPreservationActionType(s) - - return nil -} - -type PreservationActionStatus uint - -const ( - ActionStatusUnspecified PreservationActionStatus = iota - ActionStatusInProgress - ActionStatusDone - ActionStatusError - ActionStatusQueued - ActionStatusPending -) - -func NewPreservationActionStatus(status string) PreservationActionStatus { - var s PreservationActionStatus - - switch strings.ToLower(status) { - case "in progress": - s = ActionStatusInProgress - case "done": - s = ActionStatusDone - case "error": - s = ActionStatusError - case "queued": - s = ActionStatusQueued - case "pending": - s = ActionStatusPending - default: - s = ActionStatusUnspecified - } - - return s -} - -func (p PreservationActionStatus) String() string { - switch p { - case ActionStatusInProgress: - return "in progress" - case ActionStatusDone: - return "done" - case ActionStatusError: - return "error" - case ActionStatusQueued: - return "queued" - case ActionStatusPending: - return "pending" - } - - return "unspecified" -} - -func (p PreservationActionStatus) MarshalJSON() ([]byte, error) { - return json.Marshal(p.String()) -} - -func (p *PreservationActionStatus) UnmarshalJSON(b []byte) error { - var s string - if err := json.Unmarshal(b, &s); err != nil { - return err - } - - *p = NewPreservationActionStatus(s) - - return nil -} - -// PreservationAction represents a preservation action in the preservation_action table. -type PreservationAction struct { - ID uint `db:"id"` - WorkflowID string `db:"workflow_id"` - Type PreservationActionType `db:"type"` - Status PreservationActionStatus `db:"status"` - StartedAt sql.NullTime `db:"started_at"` - CompletedAt sql.NullTime `db:"completed_at"` - PackageID uint `db:"package_id"` -} - func (w *goaWrapper) PreservationActions( ctx context.Context, payload *goapackage.PreservationActionsPayload, @@ -165,7 +34,7 @@ func (w *goaWrapper) PreservationActions( preservation_actions := []*goapackage.EnduroPackagePreservationAction{} for rows.Next() { - pa := PreservationAction{} + pa := datatypes.PreservationAction{} if err := rows.StructScan(&pa); err != nil { return nil, fmt.Errorf("error scanning database result: %w", err) } @@ -216,7 +85,7 @@ func (w *goaWrapper) PreservationActions( return result, nil } -func (svc *packageImpl) CreatePreservationAction(ctx context.Context, pa *PreservationAction) error { +func (svc *packageImpl) CreatePreservationAction(ctx context.Context, pa *datatypes.PreservationAction) error { startedAt := &pa.StartedAt.Time completedAt := &pa.CompletedAt.Time if pa.StartedAt.Time.IsZero() { @@ -259,7 +128,7 @@ func (svc *packageImpl) CreatePreservationAction(ctx context.Context, pa *Preser func (svc *packageImpl) SetPreservationActionStatus( ctx context.Context, ID uint, - status PreservationActionStatus, + status enums.PreservationActionStatus, ) error { query := `UPDATE preservation_action SET status = ? WHERE id = ?` args := []interface{}{ @@ -283,7 +152,7 @@ func (svc *packageImpl) SetPreservationActionStatus( func (svc *packageImpl) CompletePreservationAction( ctx context.Context, ID uint, - status PreservationActionStatus, + status enums.PreservationActionStatus, completedAt time.Time, ) error { query := `UPDATE preservation_action SET status = ?, completed_at = ? WHERE id = ?` @@ -324,7 +193,7 @@ func (svc *packageImpl) readPreservationAction( WHERE preservation_action.id = ? ` args := []interface{}{ID} - dbItem := PreservationAction{} + dbItem := datatypes.PreservationAction{} if err := svc.db.GetContext(ctx, &dbItem, query, args...); err != nil { return nil, err diff --git a/internal/workflow/local_activities.go b/internal/workflow/local_activities.go index 3298f5888..14bcb83e1 100644 --- a/internal/workflow/local_activities.go +++ b/internal/workflow/local_activities.go @@ -115,8 +115,8 @@ type saveLocationMovePreservationActionLocalActivityParams struct { PackageID uint LocationID uuid.UUID WorkflowID string - Type package_.PreservationActionType - Status package_.PreservationActionStatus + Type enums.PreservationActionType + Status enums.PreservationActionStatus StartedAt time.Time CompletedAt time.Time } @@ -140,11 +140,11 @@ func saveLocationMovePreservationActionLocalActivity( return &saveLocationMovePreservationActionLocalActivityResult{}, err } - actionStatusToTaskStatus := map[package_.PreservationActionStatus]enums.PreservationTaskStatus{ - package_.ActionStatusUnspecified: enums.PreservationTaskStatusUnspecified, - package_.ActionStatusDone: enums.PreservationTaskStatusDone, - package_.ActionStatusInProgress: enums.PreservationTaskStatusInProgress, - package_.ActionStatusError: enums.PreservationTaskStatusError, + actionStatusToTaskStatus := map[enums.PreservationActionStatus]enums.PreservationTaskStatus{ + enums.PreservationActionStatusUnspecified: enums.PreservationTaskStatusUnspecified, + enums.PreservationActionStatusDone: enums.PreservationTaskStatusDone, + enums.PreservationActionStatusInProgress: enums.PreservationTaskStatusInProgress, + enums.PreservationActionStatusError: enums.PreservationTaskStatusError, } pt := datatypes.PreservationTask{ @@ -162,8 +162,8 @@ func saveLocationMovePreservationActionLocalActivity( type createPreservationActionLocalActivityParams struct { WorkflowID string - Type package_.PreservationActionType - Status package_.PreservationActionStatus + Type enums.PreservationActionType + Status enums.PreservationActionStatus StartedAt time.Time CompletedAt time.Time PackageID uint @@ -174,7 +174,7 @@ func createPreservationActionLocalActivity( pkgsvc package_.Service, params *createPreservationActionLocalActivityParams, ) (uint, error) { - pa := package_.PreservationAction{ + pa := datatypes.PreservationAction{ WorkflowID: params.WorkflowID, Type: params.Type, Status: params.Status, @@ -196,14 +196,14 @@ func setPreservationActionStatusLocalActivity( ctx context.Context, pkgsvc package_.Service, ID uint, - status package_.PreservationActionStatus, + status enums.PreservationActionStatus, ) (*setPreservationActionStatusLocalActivityResult, error) { return &setPreservationActionStatusLocalActivityResult{}, pkgsvc.SetPreservationActionStatus(ctx, ID, status) } type completePreservationActionLocalActivityParams struct { PreservationActionID uint - Status package_.PreservationActionStatus + Status enums.PreservationActionStatus CompletedAt time.Time } diff --git a/internal/workflow/move.go b/internal/workflow/move.go index c26d77f09..bc2009048 100644 --- a/internal/workflow/move.go +++ b/internal/workflow/move.go @@ -26,7 +26,7 @@ func (w *MoveWorkflow) Execute(ctx temporalsdk_workflow.Context, req *package_.M startedAt := temporalsdk_workflow.Now(ctx).UTC() // Assume the preservation action will be successful. - status := package_.ActionStatusDone + status := enums.PreservationActionStatusDone // Set package to in progress status. { @@ -47,20 +47,20 @@ func (w *MoveWorkflow) Execute(ctx temporalsdk_workflow.Context, req *package_.M }). Get(activityOpts, nil) if err != nil { - status = package_.ActionStatusError + status = enums.PreservationActionStatusError } } // Poll package move to permanent storage { - if status != package_.ActionStatusError { + if status != enums.PreservationActionStatusError { activityOpts := withActivityOptsForLongLivedRequest(ctx) err := temporalsdk_workflow.ExecuteActivity(activityOpts, activities.PollMoveToPermanentStorageActivityName, &activities.PollMoveToPermanentStorageActivityParams{ AIPID: req.AIPID, }). Get(activityOpts, nil) if err != nil { - status = package_.ActionStatusError + status = enums.PreservationActionStatusError } } } @@ -79,7 +79,7 @@ func (w *MoveWorkflow) Execute(ctx temporalsdk_workflow.Context, req *package_.M // Set package location. { - if status != package_.ActionStatusError { + if status != enums.PreservationActionStatusError { ctx := withLocalActivityOpts(ctx) err := temporalsdk_workflow.ExecuteLocalActivity(ctx, setLocationIDLocalActivity, w.pkgsvc, req.ID, req.LocationID). Get(ctx, nil) @@ -97,7 +97,7 @@ func (w *MoveWorkflow) Execute(ctx temporalsdk_workflow.Context, req *package_.M PackageID: req.ID, LocationID: req.LocationID, WorkflowID: temporalsdk_workflow.GetInfo(ctx).WorkflowExecution.ID, - Type: package_.ActionTypeMovePackage, + Type: enums.PreservationActionTypeMovePackage, Status: status, StartedAt: startedAt, CompletedAt: completedAt, diff --git a/internal/workflow/processing.go b/internal/workflow/processing.go index 70de82ef5..d59b2c519 100644 --- a/internal/workflow/processing.go +++ b/internal/workflow/processing.go @@ -115,7 +115,7 @@ func (w *ProcessingWorkflow) Execute(ctx temporalsdk_workflow.Context, req *pack status = enums.PackageStatusQueued // Create AIP preservation action status. - paStatus = package_.ActionStatusUnspecified + paStatus = enums.PreservationActionStatusUnspecified ) // Persist package as early as possible. @@ -165,8 +165,8 @@ func (w *ProcessingWorkflow) Execute(ctx temporalsdk_workflow.Context, req *pack }). Get(activityOpts, nil) - if paStatus != package_.ActionStatusDone { - paStatus = package_.ActionStatusError + if paStatus != enums.PreservationActionStatusDone { + paStatus = enums.PreservationActionStatusError } _ = temporalsdk_workflow.ExecuteLocalActivity(activityOpts, completePreservationActionLocalActivity, w.pkgsvc, &completePreservationActionLocalActivityParams{ @@ -228,7 +228,7 @@ func (w *ProcessingWorkflow) Execute(ctx temporalsdk_workflow.Context, req *pack status = enums.PackageStatusDone - paStatus = package_.ActionStatusDone + paStatus = enums.PreservationActionStatusDone } // Schedule deletion of the original in the watched data source. @@ -283,18 +283,18 @@ func (w *ProcessingWorkflow) SessionHandler( // Persist the preservation action for creating the AIP. { { - var preservationActionType package_.PreservationActionType + var preservationActionType enums.PreservationActionType if tinfo.req.AutoApproveAIP { - preservationActionType = package_.ActionTypeCreateAIP + preservationActionType = enums.PreservationActionTypeCreateAIP } else { - preservationActionType = package_.ActionTypeCreateAndReviewAIP + preservationActionType = enums.PreservationActionTypeCreateAndReviewAIP } ctx := withLocalActivityOpts(sessCtx) err := temporalsdk_workflow.ExecuteLocalActivity(ctx, createPreservationActionLocalActivity, w.pkgsvc, &createPreservationActionLocalActivityParams{ WorkflowID: temporalsdk_workflow.GetInfo(ctx).WorkflowExecution.ID, Type: preservationActionType, - Status: package_.ActionStatusInProgress, + Status: enums.PreservationActionStatusInProgress, StartedAt: packageStartedAt, PackageID: tinfo.req.PackageID, }). @@ -488,7 +488,7 @@ func (w *ProcessingWorkflow) SessionHandler( // Set preservation action to pending status. { ctx := withLocalActivityOpts(sessCtx) - err := temporalsdk_workflow.ExecuteLocalActivity(ctx, setPreservationActionStatusLocalActivity, w.pkgsvc, tinfo.PreservationActionID, package_.ActionStatusPending).Get(ctx, nil) + err := temporalsdk_workflow.ExecuteLocalActivity(ctx, setPreservationActionStatusLocalActivity, w.pkgsvc, tinfo.PreservationActionID, enums.PreservationActionStatusPending).Get(ctx, nil) if err != nil { return err } @@ -524,7 +524,7 @@ func (w *ProcessingWorkflow) SessionHandler( // Set preservation action to in progress status. { ctx := withLocalActivityOpts(sessCtx) - err := temporalsdk_workflow.ExecuteLocalActivity(ctx, setPreservationActionStatusLocalActivity, w.pkgsvc, tinfo.PreservationActionID, package_.ActionStatusInProgress).Get(ctx, nil) + err := temporalsdk_workflow.ExecuteLocalActivity(ctx, setPreservationActionStatusLocalActivity, w.pkgsvc, tinfo.PreservationActionID, enums.PreservationActionStatusInProgress).Get(ctx, nil) if err != nil { return err } From 94e767e8cf86137c1c631b0dde3543108f32b90c Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Wed, 27 Mar 2024 09:33:57 -0700 Subject: [PATCH 6/9] Remove eventManager persistence wrapper I originally intended the eventManager wrapper as layer to compose the event and persistence services. With the switch to including the persistence service in the package service, it is less work to compose the event and persistence services in the package service and doesn't require a separate service. --- internal/persistence/events.go | 60 ----------- internal/persistence/events_test.go | 151 ---------------------------- 2 files changed, 211 deletions(-) delete mode 100644 internal/persistence/events.go delete mode 100644 internal/persistence/events_test.go diff --git a/internal/persistence/events.go b/internal/persistence/events.go deleted file mode 100644 index ecef653d6..000000000 --- a/internal/persistence/events.go +++ /dev/null @@ -1,60 +0,0 @@ -package persistence - -import ( - "context" - - goapackage "github.com/artefactual-sdps/enduro/internal/api/gen/package_" - "github.com/artefactual-sdps/enduro/internal/datatypes" - "github.com/artefactual-sdps/enduro/internal/event" -) - -type eventManager struct { - evsvc event.EventService - inner Service -} - -var _ Service = (*eventManager)(nil) - -// WithEvents decorates a persistence service implementation with event -// publication to evsvc. -func WithEvents(evsvc event.EventService, inner Service) *eventManager { - return &eventManager{evsvc: evsvc, inner: inner} -} - -// CreatePackage creates and persists a new package using the values from pkg, -// publishes a "package created" event, then returns the updated package. -// -// The input pkg "ID" and "CreatedAt" values are ignored; the stored package -// "ID" is generated by the persistence implementation and "CreatedAt" is always -// set to the current time. -func (m *eventManager) CreatePackage(ctx context.Context, pkg *datatypes.Package) error { - err := m.inner.CreatePackage(ctx, pkg) - if err != nil { - return err - } - - // Publish a "package created" event. - ev := &goapackage.PackageCreatedEvent{ID: uint(pkg.ID), Item: pkg.Goa()} - event.PublishEvent(ctx, m.evsvc, ev) - - return nil -} - -// UpdatePackage updates the persisted package identified by id using the -// updater function, publishes a "package updated" event, then returns the -// updated package. -// -// The package "ID" and "CreatedAt" field values can not be updated with this -// method. -func (m *eventManager) UpdatePackage(ctx context.Context, id uint, updater PackageUpdater) (*datatypes.Package, error) { - pkg, err := m.inner.UpdatePackage(ctx, id, updater) - if err != nil { - return nil, err - } - - // Publish a "package updated" event. - ev := &goapackage.PackageUpdatedEvent{ID: pkg.ID, Item: pkg.Goa()} - event.PublishEvent(ctx, m.evsvc, ev) - - return pkg, nil -} diff --git a/internal/persistence/events_test.go b/internal/persistence/events_test.go deleted file mode 100644 index 0fe7e4940..000000000 --- a/internal/persistence/events_test.go +++ /dev/null @@ -1,151 +0,0 @@ -package persistence_test - -import ( - "context" - "database/sql" - "testing" - "time" - - "github.com/google/uuid" - "go.artefactual.dev/tools/mockutil" - "go.uber.org/mock/gomock" - "gotest.tools/v3/assert" - - "github.com/artefactual-sdps/enduro/internal/datatypes" - "github.com/artefactual-sdps/enduro/internal/enums" - "github.com/artefactual-sdps/enduro/internal/event" - "github.com/artefactual-sdps/enduro/internal/persistence" - mockclient "github.com/artefactual-sdps/enduro/internal/persistence/fake" -) - -var ( - CreatedAt = time.Unix(1694213364, 0) // 2023-09-08T22:49:24+00:00 - StartedAt = time.Unix(1694213435, 0) // 2023-09-08T22:50:35+00:00 -) - -func TestCreatePackage(t *testing.T) { - ctx := context.Background() - aipID := uuid.NullUUID{ - UUID: uuid.MustParse("57e9d085-5716-43d2-bad9-bba3c9a74bd8"), - Valid: true, - } - - evsvc := event.NewEventServiceInMemImpl() - sub, err := evsvc.Subscribe(ctx) - assert.NilError(t, err) - - msvc := mockclient.NewMockService(gomock.NewController(t)) - msvc. - EXPECT(). - CreatePackage(mockutil.Context(), - &datatypes.Package{ - Name: "Fake package", - WorkflowID: "workflow-1", - RunID: "d1fec389-d50f-423f-843f-a510584cc02c", - AIPID: aipID, - Status: enums.PackageStatusInProgress, - StartedAt: sql.NullTime{Time: StartedAt, Valid: true}, - }, - ). - DoAndReturn(func(ctx context.Context, p *datatypes.Package) error { - p.ID = 1 - p.CreatedAt = CreatedAt - - return nil - }) - - svc := persistence.WithEvents(evsvc, msvc) - pkg := datatypes.Package{ - Name: "Fake package", - WorkflowID: "workflow-1", - RunID: "d1fec389-d50f-423f-843f-a510584cc02c", - AIPID: aipID, - Status: enums.PackageStatusInProgress, - StartedAt: sql.NullTime{Time: StartedAt, Valid: true}, - } - - err = svc.CreatePackage(ctx, &pkg) - - assert.NilError(t, err) - assert.DeepEqual(t, pkg, datatypes.Package{ - ID: 1, - Name: "Fake package", - WorkflowID: "workflow-1", - RunID: "d1fec389-d50f-423f-843f-a510584cc02c", - AIPID: aipID, - Status: enums.PackageStatusInProgress, - CreatedAt: CreatedAt, - StartedAt: sql.NullTime{Time: StartedAt, Valid: true}, - }) - - // Verify subscriber received the event. - select { - case ev := <-sub.C(): - assert.Assert(t, ev.Event != nil) - default: - t.Fatal("expected event") - } -} - -func TestUpdatePackage(t *testing.T) { - ctx := context.Background() - aipID := uuid.NullUUID{ - UUID: uuid.MustParse("57e9d085-5716-43d2-bad9-bba3c9a74bd8"), - Valid: true, - } - completed := time.Now() - - evsvc := event.NewEventServiceInMemImpl() - sub, err := evsvc.Subscribe(ctx) - assert.NilError(t, err) - - msvc := mockclient.NewMockService(gomock.NewController(t)) - msvc. - EXPECT(). - UpdatePackage(mockutil.Context(), uint(1), mockutil.Func( - "updates package", - func(updater persistence.PackageUpdater) error { - _, err := updater(&datatypes.Package{}) - return err - }), - ). - Return(&datatypes.Package{ - ID: 1, - Name: "Fake package", - WorkflowID: "workflow-1", - RunID: "d1fec389-d50f-423f-843f-a510584cc02c", - AIPID: aipID, - Status: enums.PackageStatusDone, - CreatedAt: CreatedAt, - StartedAt: sql.NullTime{Time: StartedAt, Valid: true}, - CompletedAt: sql.NullTime{Time: completed, Valid: true}, - }, nil) - - svc := persistence.WithEvents(evsvc, msvc) - got, err := svc.UpdatePackage(ctx, uint(1), func(pkg *datatypes.Package) (*datatypes.Package, error) { - pkg.Status = enums.PackageStatusDone - pkg.CompletedAt = sql.NullTime{Time: completed, Valid: true} - return pkg, nil - }) - - assert.NilError(t, err) - assert.DeepEqual(t, got, &datatypes.Package{ - ID: 1, - Name: "Fake package", - WorkflowID: "workflow-1", - RunID: "d1fec389-d50f-423f-843f-a510584cc02c", - AIPID: aipID, - Status: enums.PackageStatusDone, - CreatedAt: CreatedAt, - StartedAt: sql.NullTime{Time: StartedAt, Valid: true}, - CompletedAt: sql.NullTime{Time: completed, Valid: true}, - }) - - // Verify subscriber received the event. - select { - case ev := <-sub.C(): - assert.Assert(t, ev.Event != nil) - default: - t.Fatal("expected event") - } -} From e584c2d3a076fb11803f705887f1225b4f19fe21 Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Tue, 26 Mar 2024 17:34:30 -0700 Subject: [PATCH 7/9] Add `persistence.CreatePreservationTask()` method - Add `CreatePreservationTask()` to the ent client - Add `CreatePreservationTask()` to the persistence service - Regenerate the persistence service mocks - Add a `convertPreservationTask()` function to convert a `db.PreservationTask` to a `datatype.PreservationTask` struct --- .../persistence/ent/client/client_test.go | 51 +++++-- internal/persistence/ent/client/convert.go | 27 +++- .../ent/client/preservation_task.go | 55 ++++++++ .../ent/client/preservation_task_test.go | 124 ++++++++++++++++++ internal/persistence/fake/mock_persistence.go | 38 ++++++ internal/persistence/persistence.go | 2 + 6 files changed, 283 insertions(+), 14 deletions(-) create mode 100644 internal/persistence/ent/client/preservation_task.go create mode 100644 internal/persistence/ent/client/preservation_task_test.go diff --git a/internal/persistence/ent/client/client_test.go b/internal/persistence/ent/client/client_test.go index 58b1453b7..730f8b996 100644 --- a/internal/persistence/ent/client/client_test.go +++ b/internal/persistence/ent/client/client_test.go @@ -29,27 +29,52 @@ func setUpClient(t *testing.T, logger logr.Logger) (*db.Client, persistence.Serv return entc, c } +func createPackage( + entc *db.Client, + name string, + status enums.PackageStatus, +) (*db.Pkg, error) { + runID := uuid.MustParse("aee9644d-6397-4b34-92f7-442ad3dd3b13") + aipID := uuid.MustParse("30223842-0650-4f79-80bd-7bf43b810656") + + return entc.Pkg.Create(). + SetName(name). + SetWorkflowID("12345"). + SetRunID(runID). + SetAipID(aipID). + SetStatus(int8(status)). + Save(context.Background()) +} + +func createPreservationAction( + entc *db.Client, + pkgID int, + status enums.PreservationActionStatus, +) (*db.PreservationAction, error) { + return entc.PreservationAction.Create(). + SetWorkflowID("12345"). + SetType(int8(enums.PreservationActionTypeCreateAIP)). + SetStatus(int8(status)). + SetPackageID(pkgID). + Save(context.Background()) +} + func TestNew(t *testing.T) { t.Run("Returns a working ent DB client", func(t *testing.T) { t.Parallel() entc, _ := setUpClient(t, logr.Discard()) - runID := uuid.New() - aipID := uuid.New() - - p, err := entc.Pkg.Create(). - SetName("testing 1-2-3"). - SetWorkflowID("12345"). - SetRunID(runID). - SetAipID(aipID). - SetStatus(int8(enums.NewPackageStatus("in progress"))). - Save(context.Background()) - + p, err := createPackage( + entc, + "testing 1-2-3", + enums.NewPackageStatus("in progress"), + ) assert.NilError(t, err) + assert.Equal(t, p.Name, "testing 1-2-3") assert.Equal(t, p.WorkflowID, "12345") - assert.Equal(t, p.RunID, runID) - assert.Equal(t, p.AipID, aipID) + assert.Equal(t, p.RunID, uuid.MustParse("aee9644d-6397-4b34-92f7-442ad3dd3b13")) + assert.Equal(t, p.AipID, uuid.MustParse("30223842-0650-4f79-80bd-7bf43b810656")) assert.Equal(t, p.Status, int8(enums.PackageStatusInProgress)) }) } diff --git a/internal/persistence/ent/client/convert.go b/internal/persistence/ent/client/convert.go index 58ef499e1..4bc1680a1 100644 --- a/internal/persistence/ent/client/convert.go +++ b/internal/persistence/ent/client/convert.go @@ -10,7 +10,7 @@ import ( "github.com/artefactual-sdps/enduro/internal/persistence/ent/db" ) -// convertPkgToPackage converts an ent `db.Pkg` package representation to a +// convertPkgToPackage converts an entgo `db.Pkg` package representation to a // `datatypes.Package` representation. func convertPkgToPackage(pkg *db.Pkg) *datatypes.Package { var started, completed sql.NullTime @@ -44,3 +44,28 @@ func convertPkgToPackage(pkg *db.Pkg) *datatypes.Package { CompletedAt: completed, } } + +// convertPreservationTask converts an entgo `db.PreservationTask` representation +// to a `datatypes.PreservationTask` representation. +func convertPreservationTask(pt *db.PreservationTask) *datatypes.PreservationTask { + var started sql.NullTime + if !pt.StartedAt.IsZero() { + started = sql.NullTime{Time: pt.StartedAt, Valid: true} + } + + var completed sql.NullTime + if !pt.CompletedAt.IsZero() { + completed = sql.NullTime{Time: pt.CompletedAt, Valid: true} + } + + return &datatypes.PreservationTask{ + ID: uint(pt.ID), + TaskID: pt.TaskID.String(), + Name: pt.Name, + Status: enums.PreservationTaskStatus(pt.Status), + StartedAt: started, + CompletedAt: completed, + Note: pt.Note, + PreservationActionID: uint(pt.PreservationActionID), + } +} diff --git a/internal/persistence/ent/client/preservation_task.go b/internal/persistence/ent/client/preservation_task.go new file mode 100644 index 000000000..8fc55914b --- /dev/null +++ b/internal/persistence/ent/client/preservation_task.go @@ -0,0 +1,55 @@ +package entclient + +import ( + "context" + "time" + + "github.com/google/uuid" + + "github.com/artefactual-sdps/enduro/internal/datatypes" +) + +func (c *client) CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error { + // Validate required fields. + taskID, err := uuid.Parse(pt.TaskID) + if err != nil { + return newParseError(err, "TaskID") + } + if pt.Name == "" { + return newRequiredFieldError("Name") + } + if pt.PreservationActionID == 0 { + return newRequiredFieldError("PreservationActionID") + } + // TODO: Validate Status. + + // Handle nullable fields. + var startedAt *time.Time + if pt.StartedAt.Valid { + startedAt = &pt.StartedAt.Time + } + + var completedAt *time.Time + if pt.CompletedAt.Valid { + completedAt = &pt.CompletedAt.Time + } + + q := c.ent.PreservationTask.Create(). + SetTaskID(taskID). + SetName(pt.Name). + SetStatus(int8(pt.Status)). + SetNillableStartedAt(startedAt). + SetNillableCompletedAt(completedAt). + SetNote(pt.Note). + SetPreservationActionID(int(pt.PreservationActionID)) + + r, err := q.Save(ctx) + if err != nil { + return newDBErrorWithDetails(err, "create preservation task") + } + + // Update value of pt with data from DB (e.g. ID). + *pt = *convertPreservationTask(r) + + return nil +} diff --git a/internal/persistence/ent/client/preservation_task_test.go b/internal/persistence/ent/client/preservation_task_test.go new file mode 100644 index 000000000..31f2d24f1 --- /dev/null +++ b/internal/persistence/ent/client/preservation_task_test.go @@ -0,0 +1,124 @@ +package entclient_test + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/go-logr/logr" + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/enums" +) + +func TestCreatePreservationTask(t *testing.T) { + taskID := "ef0193bf-a622-4a8b-b860-cda605a426b5" + started := sql.NullTime{Time: time.Now(), Valid: true} + completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} + + type params struct { + pt *datatypes.PreservationTask + zeroPreservationActionID bool + } + tests := []struct { + name string + args params + want *datatypes.PreservationTask + wantErr string + }{ + { + name: "Saves a new preservation task in the DB", + args: params{ + pt: &datatypes.PreservationTask{ + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: started, + CompletedAt: completed, + Note: "PT1 Note", + }, + }, + want: &datatypes.PreservationTask{ + ID: 1, + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: started, + CompletedAt: completed, + Note: "PT1 Note", + }, + }, + { + name: "Errors on invalid TaskID", + args: params{ + pt: &datatypes.PreservationTask{ + TaskID: "123456", + }, + }, + wantErr: "invalid data error: parse error: field \"TaskID\": invalid UUID length: 6", + }, + { + name: "Required field error for missing Name", + args: params{ + pt: &datatypes.PreservationTask{ + TaskID: "ef0193bf-a622-4a8b-b860-cda605a426b5", + }, + }, + wantErr: "invalid data error: field \"Name\" is required", + }, + { + name: "Required field error for missing PreservationActionID", + args: params{ + pt: &datatypes.PreservationTask{ + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + }, + zeroPreservationActionID: true, + }, + wantErr: "invalid data error: field \"PreservationActionID\" is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + entc, svc := setUpClient(t, logr.Discard()) + ctx := context.Background() + pkg, _ := createPackage( + entc, + "Test package", + enums.PackageStatusInProgress, + ) + pa, _ := createPreservationAction( + entc, + pkg.ID, + enums.PreservationActionStatusInProgress, + ) + + pt := *tt.args.pt // Make a local copy of pt. + + if !tt.args.zeroPreservationActionID { + pt.PreservationActionID = uint(pa.ID) + } + + err := svc.CreatePreservationTask(ctx, &pt) + if tt.wantErr != "" { + assert.Error(t, err, tt.wantErr) + return + } + assert.NilError(t, err) + + assert.Equal(t, pt.ID, tt.want.ID) + assert.Equal(t, pt.TaskID, tt.want.TaskID) + assert.Equal(t, pt.Name, tt.want.Name) + assert.Equal(t, pt.Status, tt.want.Status) + assert.Equal(t, pt.StartedAt, tt.want.StartedAt) + assert.Equal(t, pt.CompletedAt, tt.want.CompletedAt) + assert.Equal(t, pt.Note, tt.want.Note) + assert.Equal(t, pt.PreservationActionID, uint(pa.ID)) + }) + } +} diff --git a/internal/persistence/fake/mock_persistence.go b/internal/persistence/fake/mock_persistence.go index 13a90efaf..6805bb7bd 100644 --- a/internal/persistence/fake/mock_persistence.go +++ b/internal/persistence/fake/mock_persistence.go @@ -79,6 +79,44 @@ func (c *MockServiceCreatePackageCall) DoAndReturn(f func(context.Context, *data return c } +// CreatePreservationTask mocks base method. +func (m *MockService) CreatePreservationTask(arg0 context.Context, arg1 *datatypes.PreservationTask) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePreservationTask", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreatePreservationTask indicates an expected call of CreatePreservationTask. +func (mr *MockServiceMockRecorder) CreatePreservationTask(arg0, arg1 any) *MockServiceCreatePreservationTaskCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePreservationTask", reflect.TypeOf((*MockService)(nil).CreatePreservationTask), arg0, arg1) + return &MockServiceCreatePreservationTaskCall{Call: call} +} + +// MockServiceCreatePreservationTaskCall wrap *gomock.Call +type MockServiceCreatePreservationTaskCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockServiceCreatePreservationTaskCall) Return(arg0 error) *MockServiceCreatePreservationTaskCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockServiceCreatePreservationTaskCall) Do(f func(context.Context, *datatypes.PreservationTask) error) *MockServiceCreatePreservationTaskCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockServiceCreatePreservationTaskCall) DoAndReturn(f func(context.Context, *datatypes.PreservationTask) error) *MockServiceCreatePreservationTaskCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // UpdatePackage mocks base method. func (m *MockService) UpdatePackage(arg0 context.Context, arg1 uint, arg2 persistence.PackageUpdater) (*datatypes.Package, error) { m.ctrl.T.Helper() diff --git a/internal/persistence/persistence.go b/internal/persistence/persistence.go index f9a502c2b..7aed6e6c4 100644 --- a/internal/persistence/persistence.go +++ b/internal/persistence/persistence.go @@ -28,4 +28,6 @@ type Service interface { // (e.g. ID, CreatedAt). CreatePackage(context.Context, *datatypes.Package) error UpdatePackage(context.Context, uint, PackageUpdater) (*datatypes.Package, error) + + CreatePreservationTask(context.Context, *datatypes.PreservationTask) error } From 9ac3b101459fca3c2f6b877a09624cd7fb665b8a Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Wed, 27 Mar 2024 11:39:15 -0700 Subject: [PATCH 8/9] Use `persistence.CreatePreservationTask()` Fixes #900 - Use `persistence.CreatePreservationTask()` instead of a SQL query in the package service - Add a `preservationTaskToGoa()` function to convert `datatypes.PreservationTask` to a Goa `EnduroPackagePreservationTask` struct [skip codecov] --- internal/package_/convert.go | 27 ++++ internal/package_/preservation_task.go | 38 +----- internal/package_/preservation_task_test.go | 135 ++++++++++++++++++++ 3 files changed, 168 insertions(+), 32 deletions(-) create mode 100644 internal/package_/convert.go create mode 100644 internal/package_/preservation_task_test.go diff --git a/internal/package_/convert.go b/internal/package_/convert.go new file mode 100644 index 000000000..3092275d3 --- /dev/null +++ b/internal/package_/convert.go @@ -0,0 +1,27 @@ +package package_ + +import ( + "go.artefactual.dev/tools/ref" + + goapackage "github.com/artefactual-sdps/enduro/internal/api/gen/package_" + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/db" +) + +// preservationTaskToGoa returns the API representation of a preservation task. +func preservationTaskToGoa(pt *datatypes.PreservationTask) *goapackage.EnduroPackagePreservationTask { + return &goapackage.EnduroPackagePreservationTask{ + ID: pt.ID, + TaskID: pt.TaskID, + Name: pt.Name, + Status: pt.Status.String(), + + // TODO: Make Goa StartedAt a pointer to a string to avoid having to + // convert a null time to an empty (zero value) string. + StartedAt: ref.DerefZero(db.FormatOptionalTime(pt.CompletedAt)), + + CompletedAt: db.FormatOptionalTime(pt.CompletedAt), + Note: &pt.Note, + PreservationActionID: &pt.PreservationActionID, + } +} diff --git a/internal/package_/preservation_task.go b/internal/package_/preservation_task.go index f34614d1f..a12a8c96b 100644 --- a/internal/package_/preservation_task.go +++ b/internal/package_/preservation_task.go @@ -15,42 +15,16 @@ import ( ) func (svc *packageImpl) CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error { - startedAt := &pt.StartedAt.Time - completedAt := &pt.CompletedAt.Time - if pt.StartedAt.Time.IsZero() { - startedAt = nil - } - if pt.CompletedAt.Time.IsZero() { - completedAt = nil - } - - query := `INSERT INTO preservation_task (task_id, name, status, started_at, completed_at, note, preservation_action_id) VALUES (?, ?, ?, ?, ?, ?, ?)` - args := []interface{}{ - pt.TaskID, - pt.Name, - pt.Status, - startedAt, - completedAt, - pt.Note, - pt.PreservationActionID, - } - - res, err := svc.db.ExecContext(ctx, query, args...) + err := svc.perSvc.CreatePreservationTask(ctx, pt) if err != nil { - return fmt.Errorf("error inserting preservation task: %w", err) + return fmt.Errorf("preservation task: create: %v", err) } - var id int64 - if id, err = res.LastInsertId(); err != nil { - return fmt.Errorf("error retrieving insert ID: %w", err) - } - - pt.ID = uint(id) - - if item, err := svc.readPreservationTask(ctx, pt.ID); err == nil { - ev := &goapackage.PreservationTaskCreatedEvent{ID: pt.ID, Item: item} - event.PublishEvent(ctx, svc.evsvc, ev) + ev := &goapackage.PreservationTaskCreatedEvent{ + ID: pt.ID, + Item: preservationTaskToGoa(pt), } + event.PublishEvent(ctx, svc.evsvc, ev) return nil } diff --git a/internal/package_/preservation_task_test.go b/internal/package_/preservation_task_test.go new file mode 100644 index 000000000..309244d8b --- /dev/null +++ b/internal/package_/preservation_task_test.go @@ -0,0 +1,135 @@ +package package__test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "go.artefactual.dev/tools/mockutil" + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/enums" + persistence_fake "github.com/artefactual-sdps/enduro/internal/persistence/fake" +) + +func TestCreatePreservationTask(t *testing.T) { + taskID := "a499e8fc-7309-4e26-b39d-d8ab68466c27" + + type test struct { + name string + pt datatypes.PreservationTask + mock func(*persistence_fake.MockService, datatypes.PreservationTask) *persistence_fake.MockService + want datatypes.PreservationTask + wantErr string + } + for _, tt := range []test{ + { + name: "Creates a preservation task", + pt: datatypes.PreservationTask{ + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + PreservationActionID: 11, + }, + want: datatypes.PreservationTask{ + ID: 1, + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + PreservationActionID: 11, + }, + mock: func(svc *persistence_fake.MockService, pt datatypes.PreservationTask) *persistence_fake.MockService { + svc.EXPECT(). + CreatePreservationTask(mockutil.Context(), &pt). + DoAndReturn( + func(ctx context.Context, pt *datatypes.PreservationTask) error { + pt.ID = 1 + return nil + }, + ) + return svc + }, + }, + { + name: "Creates a preservation task with optional values", + pt: datatypes.PreservationTask{ + TaskID: taskID, + Name: "PT2", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 41, 0, time.UTC), + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 43, 0, time.UTC), + Valid: true, + }, + Note: "PT2 Note", + PreservationActionID: 12, + }, + mock: func(svc *persistence_fake.MockService, pt datatypes.PreservationTask) *persistence_fake.MockService { + svc.EXPECT(). + CreatePreservationTask(mockutil.Context(), &pt). + DoAndReturn( + func(ctx context.Context, pt *datatypes.PreservationTask) error { + pt.ID = 2 + return nil + }, + ) + return svc + }, + want: datatypes.PreservationTask{ + ID: 2, + TaskID: taskID, + Name: "PT2", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 41, 0, time.UTC), + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 43, 0, time.UTC), + Valid: true, + }, + Note: "PT2 Note", + PreservationActionID: 12, + }, + }, + { + name: "Errors creating a package with a missing TaskID", + pt: datatypes.PreservationTask{}, + mock: func(svc *persistence_fake.MockService, pt datatypes.PreservationTask) *persistence_fake.MockService { + svc.EXPECT(). + CreatePreservationTask(mockutil.Context(), &pt). + Return( + fmt.Errorf("invalid data error: field \"TaskID\" is required"), + ) + return svc + }, + wantErr: "preservation task: create: invalid data error: field \"TaskID\" is required", + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pkgSvc, perSvc := testSvc(t) + if tt.mock != nil { + tt.mock(perSvc, tt.pt) + } + + pt := tt.pt + err := pkgSvc.CreatePreservationTask(context.Background(), &pt) + + if tt.wantErr != "" { + assert.Error(t, err, tt.wantErr) + return + } + + assert.NilError(t, err) + assert.DeepEqual(t, pt, tt.want) + }) + } +} From 049f105ce7a9c01b4a6ef9af9b4c30298bde798c Mon Sep 17 00:00:00 2001 From: David Juhasz Date: Thu, 28 Mar 2024 08:47:35 -0700 Subject: [PATCH 9/9] Cosmetic: split long lines with golines [skip codecov] --- internal/package_/package_.go | 7 ++++++- internal/package_/preservation_task.go | 13 +++++++++++-- internal/persistence/ent/client/package.go | 6 +++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/internal/package_/package_.go b/internal/package_/package_.go index bfdec1404..9b6643383 100644 --- a/internal/package_/package_.go +++ b/internal/package_/package_.go @@ -36,7 +36,12 @@ type Service interface { SetLocationID(ctx context.Context, ID uint, locationID uuid.UUID) error CreatePreservationAction(ctx context.Context, pa *datatypes.PreservationAction) error SetPreservationActionStatus(ctx context.Context, ID uint, status enums.PreservationActionStatus) error - CompletePreservationAction(ctx context.Context, ID uint, status enums.PreservationActionStatus, completedAt time.Time) error + CompletePreservationAction( + ctx context.Context, + ID uint, + status enums.PreservationActionStatus, + completedAt time.Time, + ) error CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error CompletePreservationTask( ctx context.Context, diff --git a/internal/package_/preservation_task.go b/internal/package_/preservation_task.go index a12a8c96b..18ab76469 100644 --- a/internal/package_/preservation_task.go +++ b/internal/package_/preservation_task.go @@ -29,7 +29,13 @@ func (svc *packageImpl) CreatePreservationTask(ctx context.Context, pt *datatype return nil } -func (svc *packageImpl) CompletePreservationTask(ctx context.Context, ID uint, status enums.PreservationTaskStatus, completedAt time.Time, note *string) error { +func (svc *packageImpl) CompletePreservationTask( + ctx context.Context, + ID uint, + status enums.PreservationTaskStatus, + completedAt time.Time, + note *string, +) error { var query string args := []interface{}{} @@ -54,7 +60,10 @@ func (svc *packageImpl) CompletePreservationTask(ctx context.Context, ID uint, s return nil } -func (svc *packageImpl) readPreservationTask(ctx context.Context, ID uint) (*goapackage.EnduroPackagePreservationTask, error) { +func (svc *packageImpl) readPreservationTask( + ctx context.Context, + ID uint, +) (*goapackage.EnduroPackagePreservationTask, error) { query := ` SELECT preservation_task.id, diff --git a/internal/persistence/ent/client/package.go b/internal/persistence/ent/client/package.go index d0f90203c..4f19f3e71 100644 --- a/internal/persistence/ent/client/package.go +++ b/internal/persistence/ent/client/package.go @@ -73,7 +73,11 @@ func (c *client) CreatePackage(ctx context.Context, pkg *datatypes.Package) erro // // The package "ID" and "CreatedAt" field values can not be updated with this // method. -func (c *client) UpdatePackage(ctx context.Context, id uint, updater persistence.PackageUpdater) (*datatypes.Package, error) { +func (c *client) UpdatePackage( + ctx context.Context, + id uint, + updater persistence.PackageUpdater, +) (*datatypes.Package, error) { tx, err := c.ent.BeginTx(ctx, nil) if err != nil { return nil, newDBError(err)