diff --git a/internal/persistence/ent/client/preservation_task.go b/internal/persistence/ent/client/preservation_task.go index 8fc55914b..892536626 100644 --- a/internal/persistence/ent/client/preservation_task.go +++ b/internal/persistence/ent/client/preservation_task.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/artefactual-sdps/enduro/internal/datatypes" + "github.com/artefactual-sdps/enduro/internal/persistence" ) func (c *client) CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error { @@ -53,3 +54,56 @@ func (c *client) CreatePreservationTask(ctx context.Context, pt *datatypes.Prese return nil } + +func (c *client) UpdatePreservationTask( + ctx context.Context, + id uint, + updater persistence.PresTaskUpdater, +) (*datatypes.PreservationTask, error) { + tx, err := c.ent.BeginTx(ctx, nil) + if err != nil { + return nil, newDBErrorWithDetails(err, "update preservation task") + } + + pt, err := tx.PreservationTask.Get(ctx, int(id)) + if err != nil { + return nil, rollback(tx, newDBError(err)) + } + + up, err := updater(convertPreservationTask(pt)) + if err != nil { + return nil, rollback(tx, newUpdaterError(err)) + } + + // Set required column values. + taskID, err := uuid.Parse(up.TaskID) + if err != nil { + return nil, rollback(tx, newParseError(err, "TaskID")) + } + + q := tx.PreservationTask.UpdateOneID(int(id)). + SetTaskID(taskID). + SetName(up.Name). + SetStatus(int8(up.Status)). + SetNote(up.Note). + SetPreservationActionID(int(up.PreservationActionID)) + + // Set nullable column values. + if up.StartedAt.Valid { + q.SetStartedAt(up.StartedAt.Time) + } + if up.CompletedAt.Valid { + q.SetCompletedAt(up.CompletedAt.Time) + } + + // Save changes. + pt, err = q.Save(ctx) + if err != nil { + return nil, rollback(tx, newDBError(err)) + } + if err = tx.Commit(); err != nil { + return nil, rollback(tx, newDBError(err)) + } + + return convertPreservationTask(pt), nil +} diff --git a/internal/persistence/ent/client/preservation_task_test.go b/internal/persistence/ent/client/preservation_task_test.go index 31f2d24f1..7d956391f 100644 --- a/internal/persistence/ent/client/preservation_task_test.go +++ b/internal/persistence/ent/client/preservation_task_test.go @@ -3,16 +3,45 @@ package entclient_test import ( "context" "database/sql" + "errors" + "fmt" "testing" "time" "github.com/go-logr/logr" + "github.com/google/uuid" "gotest.tools/v3/assert" "github.com/artefactual-sdps/enduro/internal/datatypes" "github.com/artefactual-sdps/enduro/internal/enums" + "github.com/artefactual-sdps/enduro/internal/persistence" + "github.com/artefactual-sdps/enduro/internal/persistence/ent/db" ) +func addDBFixtures( + t *testing.T, + entc *db.Client, +) (*db.PreservationAction, *db.PreservationAction) { + t.Helper() + + pkg, err := createPackage(entc, "P1", enums.PackageStatusInProgress) + if err != nil { + t.Errorf("create package: %v", err) + } + + pa, err := createPreservationAction(entc, pkg.ID, enums.PreservationActionStatusInProgress) + if err != nil { + t.Errorf("create preservation action: %v", err) + } + + pa2, err := createPreservationAction(entc, pkg.ID, enums.PreservationActionStatusDone) + if err != nil { + t.Errorf("create preservation action 2: %v", err) + } + + return pa, pa2 +} + func TestCreatePreservationTask(t *testing.T) { taskID := "ef0193bf-a622-4a8b-b860-cda605a426b5" started := sql.NullTime{Time: time.Now(), Valid: true} @@ -122,3 +151,175 @@ func TestCreatePreservationTask(t *testing.T) { }) } } + +func TestUpdatePreservationTask(t *testing.T) { + taskID := uuid.MustParse("c5f7c35a-d5a6-4e00-b4da-b036ce5b40bc") + taskID2 := uuid.MustParse("c04d0191-d7ce-46dd-beff-92d6830082ff") + + started := sql.NullTime{ + Time: time.Date(2024, 3, 31, 10, 11, 12, 0, time.UTC), + Valid: true, + } + started2 := sql.NullTime{ + Time: time.Date(2024, 4, 1, 17, 5, 49, 0, time.UTC), + Valid: true, + } + + completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true} + completed2 := sql.NullTime{Time: started2.Time.Add(time.Second), Valid: true} + + type params struct { + pt *datatypes.PreservationTask + updater persistence.PresTaskUpdater + } + tests := []struct { + name string + args params + want *datatypes.PreservationTask + wantErr string + }{ + { + name: "Updates all preservation task columns", + args: params{ + pt: &datatypes.PreservationTask{ + TaskID: taskID.String(), + Name: "PT 1", + Status: enums.PreservationTaskStatusInProgress, + StartedAt: started, + CompletedAt: completed, + Note: "PT1 Note", + }, + updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) { + p.ID = 100 // No-op, can't update ID. + p.Name = "PT1 Update" + p.TaskID = taskID2.String() + p.Status = enums.PreservationTaskStatusDone + p.StartedAt = started2 + p.CompletedAt = completed2 + p.Note = "PT1 Note updated" + return p, nil + }, + }, + want: &datatypes.PreservationTask{ + TaskID: taskID2.String(), + Name: "PT1 Update", + Status: enums.PreservationTaskStatusDone, + StartedAt: started2, + CompletedAt: completed2, + Note: "PT1 Note updated", + }, + }, + { + name: "Updates selected preservation task columns", + args: params{ + pt: &datatypes.PreservationTask{ + Name: "PT 1", + TaskID: taskID.String(), + Status: enums.PreservationTaskStatusInProgress, + StartedAt: started, + }, + updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) { + p.Status = enums.PreservationTaskStatusDone + p.CompletedAt = completed + p.Note = "PT1 Note updated" + return p, nil + }, + }, + want: &datatypes.PreservationTask{ + TaskID: taskID.String(), + Name: "PT 1", + Status: enums.PreservationTaskStatusDone, + StartedAt: started, + CompletedAt: completed, + Note: "PT1 Note updated", + }, + }, + { + name: "Errors when target preservation task isn't found", + args: params{ + updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) { + return nil, errors.New("Bad input") + }, + }, + wantErr: "not found error: db: preservation_task not found", + }, + { + name: "Errors when the updater fails", + args: params{ + pt: &datatypes.PreservationTask{ + Name: "PT 1", + TaskID: taskID.String(), + Status: enums.PreservationTaskStatusInProgress, + }, + updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) { + return nil, fmt.Errorf("Bad input") + }, + }, + wantErr: "invalid data error: updater error: Bad input", + }, + { + name: "Errors on an invalid TaskID", + args: params{ + pt: &datatypes.PreservationTask{ + Name: "PT 1", + TaskID: taskID.String(), + Status: enums.PreservationTaskStatusInProgress, + }, + updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) { + p.TaskID = "123456" + return p, nil + }, + }, + wantErr: "invalid data error: parse error: field \"TaskID\": invalid UUID length: 6", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + entc, svc := setUpClient(t, logr.Discard()) + pa, pa2 := addDBFixtures(t, entc) + + updater := tt.args.updater + var id uint + if tt.args.pt != nil { + pt := *tt.args.pt // Make a local copy of pt. + pt.PreservationActionID = uint(pa.ID) + + // Create preservation task to be updated. + err := svc.CreatePreservationTask(ctx, &pt) + if err != nil { + t.Errorf("create preservation task: %v", err) + } + id = pt.ID + + // Update PreservationActionID to pa2.ID. + updater = func(pt *datatypes.PreservationTask) (*datatypes.PreservationTask, error) { + pt, err := tt.args.updater(pt) + if err != nil { + return nil, err + } + pt.PreservationActionID = uint(pa2.ID) + + return pt, nil + } + } + + pp, err := svc.UpdatePreservationTask(ctx, id, updater) + if tt.wantErr != "" { + assert.Error(t, err, tt.wantErr) + return + } + + assert.Equal(t, pp.ID, id) + assert.Equal(t, pp.TaskID, tt.want.TaskID) + assert.Equal(t, pp.Name, tt.want.Name) + assert.Equal(t, pp.Status, tt.want.Status) + assert.Equal(t, pp.StartedAt, tt.want.StartedAt) + assert.Equal(t, pp.CompletedAt, tt.want.CompletedAt) + assert.Equal(t, pp.Note, tt.want.Note) + assert.Equal(t, pp.PreservationActionID, uint(pa2.ID)) + }) + } +} diff --git a/internal/persistence/fake/mock_persistence.go b/internal/persistence/fake/mock_persistence.go index 6805bb7bd..5b6404df2 100644 --- a/internal/persistence/fake/mock_persistence.go +++ b/internal/persistence/fake/mock_persistence.go @@ -155,3 +155,42 @@ func (c *MockServiceUpdatePackageCall) DoAndReturn(f func(context.Context, uint, c.Call = c.Call.DoAndReturn(f) return c } + +// UpdatePreservationTask mocks base method. +func (m *MockService) UpdatePreservationTask(arg0 context.Context, arg1 uint, arg2 persistence.PresTaskUpdater) (*datatypes.PreservationTask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdatePreservationTask", arg0, arg1, arg2) + ret0, _ := ret[0].(*datatypes.PreservationTask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdatePreservationTask indicates an expected call of UpdatePreservationTask. +func (mr *MockServiceMockRecorder) UpdatePreservationTask(arg0, arg1, arg2 any) *MockServiceUpdatePreservationTaskCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePreservationTask", reflect.TypeOf((*MockService)(nil).UpdatePreservationTask), arg0, arg1, arg2) + return &MockServiceUpdatePreservationTaskCall{Call: call} +} + +// MockServiceUpdatePreservationTaskCall wrap *gomock.Call +type MockServiceUpdatePreservationTaskCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockServiceUpdatePreservationTaskCall) Return(arg0 *datatypes.PreservationTask, arg1 error) *MockServiceUpdatePreservationTaskCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockServiceUpdatePreservationTaskCall) Do(f func(context.Context, uint, persistence.PresTaskUpdater) (*datatypes.PreservationTask, error)) *MockServiceUpdatePreservationTaskCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockServiceUpdatePreservationTaskCall) DoAndReturn(f func(context.Context, uint, persistence.PresTaskUpdater) (*datatypes.PreservationTask, error)) *MockServiceUpdatePreservationTaskCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/internal/persistence/persistence.go b/internal/persistence/persistence.go index 7aed6e6c4..4723b3aa8 100644 --- a/internal/persistence/persistence.go +++ b/internal/persistence/persistence.go @@ -19,7 +19,8 @@ var ( ) type ( - PackageUpdater func(*datatypes.Package) (*datatypes.Package, error) + PackageUpdater func(*datatypes.Package) (*datatypes.Package, error) + PresTaskUpdater func(*datatypes.PreservationTask) (*datatypes.PreservationTask, error) ) type Service interface { @@ -30,4 +31,5 @@ type Service interface { UpdatePackage(context.Context, uint, PackageUpdater) (*datatypes.Package, error) CreatePreservationTask(context.Context, *datatypes.PreservationTask) error + UpdatePreservationTask(ctx context.Context, id uint, updater PresTaskUpdater) (*datatypes.PreservationTask, error) } diff --git a/internal/persistence/telemetry.go b/internal/persistence/telemetry.go index e057be771..777f44e76 100644 --- a/internal/persistence/telemetry.go +++ b/internal/persistence/telemetry.go @@ -70,3 +70,21 @@ func (w *wrapper) CreatePreservationTask(ctx context.Context, pt *datatypes.Pres return nil } + +func (w *wrapper) UpdatePreservationTask( + ctx context.Context, + id uint, + updater PresTaskUpdater, +) (*datatypes.PreservationTask, error) { + ctx, span := w.tracer.Start(ctx, "UpdatePreservationTask") + defer span.End() + span.SetAttributes(attribute.Int("id", int(id))) + + r, err := w.wrapped.UpdatePreservationTask(ctx, id, updater) + if err != nil { + telemetry.RecordError(span, err) + return nil, updateError(err, "UpdatePreservationTask") + } + + return r, nil +}