diff --git a/internal/package_/convert.go b/internal/package_/convert.go new file mode 100644 index 000000000..3092275d3 --- /dev/null +++ b/internal/package_/convert.go @@ -0,0 +1,27 @@ +package package_ + +import ( + "go.artefactual.dev/tools/ref" + + goapackage "github.com/artefactual-sdps/enduro/internal/api/gen/package_" + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/db" +) + +// preservationTaskToGoa returns the API representation of a preservation task. +func preservationTaskToGoa(pt *datatypes.PreservationTask) *goapackage.EnduroPackagePreservationTask { + return &goapackage.EnduroPackagePreservationTask{ + ID: pt.ID, + TaskID: pt.TaskID, + Name: pt.Name, + Status: pt.Status.String(), + + // TODO: Make Goa StartedAt a pointer to a string to avoid having to + // convert a null time to an empty (zero value) string. + StartedAt: ref.DerefZero(db.FormatOptionalTime(pt.CompletedAt)), + + CompletedAt: db.FormatOptionalTime(pt.CompletedAt), + Note: &pt.Note, + PreservationActionID: &pt.PreservationActionID, + } +} diff --git a/internal/package_/preservation_task.go b/internal/package_/preservation_task.go index f34614d1f..a12a8c96b 100644 --- a/internal/package_/preservation_task.go +++ b/internal/package_/preservation_task.go @@ -15,42 +15,16 @@ import ( ) func (svc *packageImpl) CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error { - startedAt := &pt.StartedAt.Time - completedAt := &pt.CompletedAt.Time - if pt.StartedAt.Time.IsZero() { - startedAt = nil - } - if pt.CompletedAt.Time.IsZero() { - completedAt = nil - } - - query := `INSERT INTO preservation_task (task_id, name, status, started_at, completed_at, note, preservation_action_id) VALUES (?, ?, ?, ?, ?, ?, ?)` - args := []interface{}{ - pt.TaskID, - pt.Name, - pt.Status, - startedAt, - completedAt, - pt.Note, - pt.PreservationActionID, - } - - res, err := svc.db.ExecContext(ctx, query, args...) + err := svc.perSvc.CreatePreservationTask(ctx, pt) if err != nil { - return fmt.Errorf("error inserting preservation task: %w", err) + return fmt.Errorf("preservation task: create: %v", err) } - var id int64 - if id, err = res.LastInsertId(); err != nil { - return fmt.Errorf("error retrieving insert ID: %w", err) - } - - pt.ID = uint(id) - - if item, err := svc.readPreservationTask(ctx, pt.ID); err == nil { - ev := &goapackage.PreservationTaskCreatedEvent{ID: pt.ID, Item: item} - event.PublishEvent(ctx, svc.evsvc, ev) + ev := &goapackage.PreservationTaskCreatedEvent{ + ID: pt.ID, + Item: preservationTaskToGoa(pt), } + event.PublishEvent(ctx, svc.evsvc, ev) return nil } diff --git a/internal/package_/preservation_task_test.go b/internal/package_/preservation_task_test.go new file mode 100644 index 000000000..309244d8b --- /dev/null +++ b/internal/package_/preservation_task_test.go @@ -0,0 +1,135 @@ +package package__test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "go.artefactual.dev/tools/mockutil" + "gotest.tools/v3/assert" + + "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/enums" + persistence_fake "github.com/artefactual-sdps/enduro/internal/persistence/fake" +) + +func TestCreatePreservationTask(t *testing.T) { + taskID := "a499e8fc-7309-4e26-b39d-d8ab68466c27" + + type test struct { + name string + pt datatypes.PreservationTask + mock func(*persistence_fake.MockService, datatypes.PreservationTask) *persistence_fake.MockService + want datatypes.PreservationTask + wantErr string + } + for _, tt := range []test{ + { + name: "Creates a preservation task", + pt: datatypes.PreservationTask{ + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + PreservationActionID: 11, + }, + want: datatypes.PreservationTask{ + ID: 1, + TaskID: taskID, + Name: "PT1", + Status: enums.PreservationTaskStatusInProgress, + PreservationActionID: 11, + }, + mock: func(svc *persistence_fake.MockService, pt datatypes.PreservationTask) *persistence_fake.MockService { + svc.EXPECT(). + CreatePreservationTask(mockutil.Context(), &pt). + DoAndReturn( + func(ctx context.Context, pt *datatypes.PreservationTask) error { + pt.ID = 1 + return nil + }, + ) + return svc + }, + }, + { + name: "Creates a preservation task with optional values", + pt: datatypes.PreservationTask{ + TaskID: taskID, + Name: "PT2", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 41, 0, time.UTC), + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 43, 0, time.UTC), + Valid: true, + }, + Note: "PT2 Note", + PreservationActionID: 12, + }, + mock: func(svc *persistence_fake.MockService, pt datatypes.PreservationTask) *persistence_fake.MockService { + svc.EXPECT(). + CreatePreservationTask(mockutil.Context(), &pt). + DoAndReturn( + func(ctx context.Context, pt *datatypes.PreservationTask) error { + pt.ID = 2 + return nil + }, + ) + return svc + }, + want: datatypes.PreservationTask{ + ID: 2, + TaskID: taskID, + Name: "PT2", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 41, 0, time.UTC), + Valid: true, + }, + CompletedAt: sql.NullTime{ + Time: time.Date(2024, 3, 27, 11, 32, 43, 0, time.UTC), + Valid: true, + }, + Note: "PT2 Note", + PreservationActionID: 12, + }, + }, + { + name: "Errors creating a package with a missing TaskID", + pt: datatypes.PreservationTask{}, + mock: func(svc *persistence_fake.MockService, pt datatypes.PreservationTask) *persistence_fake.MockService { + svc.EXPECT(). + CreatePreservationTask(mockutil.Context(), &pt). + Return( + fmt.Errorf("invalid data error: field \"TaskID\" is required"), + ) + return svc + }, + wantErr: "preservation task: create: invalid data error: field \"TaskID\" is required", + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pkgSvc, perSvc := testSvc(t) + if tt.mock != nil { + tt.mock(perSvc, tt.pt) + } + + pt := tt.pt + err := pkgSvc.CreatePreservationTask(context.Background(), &pt) + + if tt.wantErr != "" { + assert.Error(t, err, tt.wantErr) + return + } + + assert.NilError(t, err) + assert.DeepEqual(t, pt, tt.want) + }) + } +}