Skip to content

Commit

Permalink
Add persistence.UpdatePreservationTask method
Browse files Browse the repository at this point in the history
Refs #907

- Add an `UpdatePreservationTask()` method to to the ent persistence
client
- Add an `UpdatePreservationTask()` method to the persistence service
- Add an `UpdatePreservationTask()` method to the persistence telemetry
wrapper
- Regenerate persistence mocks
  • Loading branch information
djjuhasz committed Apr 2, 2024
1 parent f522aed commit dd004da
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 1 deletion.
54 changes: 54 additions & 0 deletions internal/persistence/ent/client/preservation_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/google/uuid"

"github.com/artefactual-sdps/enduro/internal/datatypes"
"github.com/artefactual-sdps/enduro/internal/persistence"
)

func (c *client) CreatePreservationTask(ctx context.Context, pt *datatypes.PreservationTask) error {
Expand Down Expand Up @@ -53,3 +54,56 @@ func (c *client) CreatePreservationTask(ctx context.Context, pt *datatypes.Prese

return nil
}

func (c *client) UpdatePreservationTask(
ctx context.Context,
id uint,
updater persistence.PresTaskUpdater,
) (*datatypes.PreservationTask, error) {
tx, err := c.ent.BeginTx(ctx, nil)
if err != nil {
return nil, newDBErrorWithDetails(err, "update preservation task")

Check warning on line 65 in internal/persistence/ent/client/preservation_task.go

View check run for this annotation

Codecov / codecov/patch

internal/persistence/ent/client/preservation_task.go#L65

Added line #L65 was not covered by tests
}

pt, err := tx.PreservationTask.Get(ctx, int(id))
if err != nil {
return nil, rollback(tx, newDBError(err))
}

up, err := updater(convertPreservationTask(pt))
if err != nil {
return nil, rollback(tx, newUpdaterError(err))
}

// Set required column values.
taskID, err := uuid.Parse(up.TaskID)
if err != nil {
return nil, rollback(tx, newParseError(err, "TaskID"))
}

q := tx.PreservationTask.UpdateOneID(int(id)).
SetTaskID(taskID).
SetName(up.Name).
SetStatus(int8(up.Status)).
SetNote(up.Note).
SetPreservationActionID(int(up.PreservationActionID))

// Set nullable column values.
if up.StartedAt.Valid {
q.SetStartedAt(up.StartedAt.Time)
}
if up.CompletedAt.Valid {
q.SetCompletedAt(up.CompletedAt.Time)
}

// Save changes.
pt, err = q.Save(ctx)
if err != nil {
return nil, rollback(tx, newDBError(err))

Check warning on line 102 in internal/persistence/ent/client/preservation_task.go

View check run for this annotation

Codecov / codecov/patch

internal/persistence/ent/client/preservation_task.go#L102

Added line #L102 was not covered by tests
}
if err = tx.Commit(); err != nil {
return nil, rollback(tx, newDBError(err))

Check warning on line 105 in internal/persistence/ent/client/preservation_task.go

View check run for this annotation

Codecov / codecov/patch

internal/persistence/ent/client/preservation_task.go#L105

Added line #L105 was not covered by tests
}

return convertPreservationTask(pt), nil
}
201 changes: 201 additions & 0 deletions internal/persistence/ent/client/preservation_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,45 @@ package entclient_test
import (
"context"
"database/sql"
"errors"
"fmt"
"testing"
"time"

"github.com/go-logr/logr"
"github.com/google/uuid"
"gotest.tools/v3/assert"

"github.com/artefactual-sdps/enduro/internal/datatypes"
"github.com/artefactual-sdps/enduro/internal/enums"
"github.com/artefactual-sdps/enduro/internal/persistence"
"github.com/artefactual-sdps/enduro/internal/persistence/ent/db"
)

func addDBFixtures(
t *testing.T,
entc *db.Client,
) (*db.PreservationAction, *db.PreservationAction) {
t.Helper()

pkg, err := createPackage(entc, "P1", enums.PackageStatusInProgress)
if err != nil {
t.Errorf("create package: %v", err)
}

pa, err := createPreservationAction(entc, pkg.ID, enums.PreservationActionStatusInProgress)
if err != nil {
t.Errorf("create preservation action: %v", err)
}

pa2, err := createPreservationAction(entc, pkg.ID, enums.PreservationActionStatusDone)
if err != nil {
t.Errorf("create preservation action 2: %v", err)
}

return pa, pa2
}

func TestCreatePreservationTask(t *testing.T) {
taskID := "ef0193bf-a622-4a8b-b860-cda605a426b5"
started := sql.NullTime{Time: time.Now(), Valid: true}
Expand Down Expand Up @@ -122,3 +151,175 @@ func TestCreatePreservationTask(t *testing.T) {
})
}
}

func TestUpdatePreservationTask(t *testing.T) {
taskID := uuid.MustParse("c5f7c35a-d5a6-4e00-b4da-b036ce5b40bc")
taskID2 := uuid.MustParse("c04d0191-d7ce-46dd-beff-92d6830082ff")

started := sql.NullTime{
Time: time.Date(2024, 3, 31, 10, 11, 12, 0, time.UTC),
Valid: true,
}
started2 := sql.NullTime{
Time: time.Date(2024, 4, 1, 17, 5, 49, 0, time.UTC),
Valid: true,
}

completed := sql.NullTime{Time: started.Time.Add(time.Second), Valid: true}
completed2 := sql.NullTime{Time: started2.Time.Add(time.Second), Valid: true}

type params struct {
pt *datatypes.PreservationTask
updater persistence.PresTaskUpdater
}
tests := []struct {
name string
args params
want *datatypes.PreservationTask
wantErr string
}{
{
name: "Updates all preservation task columns",
args: params{
pt: &datatypes.PreservationTask{
TaskID: taskID.String(),
Name: "PT 1",
Status: enums.PreservationTaskStatusInProgress,
StartedAt: started,
CompletedAt: completed,
Note: "PT1 Note",
},
updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) {
p.ID = 100 // No-op, can't update ID.
p.Name = "PT1 Update"
p.TaskID = taskID2.String()
p.Status = enums.PreservationTaskStatusDone
p.StartedAt = started2
p.CompletedAt = completed2
p.Note = "PT1 Note updated"
return p, nil
},
},
want: &datatypes.PreservationTask{
TaskID: taskID2.String(),
Name: "PT1 Update",
Status: enums.PreservationTaskStatusDone,
StartedAt: started2,
CompletedAt: completed2,
Note: "PT1 Note updated",
},
},
{
name: "Updates selected preservation task columns",
args: params{
pt: &datatypes.PreservationTask{
Name: "PT 1",
TaskID: taskID.String(),
Status: enums.PreservationTaskStatusInProgress,
StartedAt: started,
},
updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) {
p.Status = enums.PreservationTaskStatusDone
p.CompletedAt = completed
p.Note = "PT1 Note updated"
return p, nil
},
},
want: &datatypes.PreservationTask{
TaskID: taskID.String(),
Name: "PT 1",
Status: enums.PreservationTaskStatusDone,
StartedAt: started,
CompletedAt: completed,
Note: "PT1 Note updated",
},
},
{
name: "Errors when target preservation task isn't found",
args: params{
updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) {
return nil, errors.New("Bad input")
},
},
wantErr: "not found error: db: preservation_task not found",
},
{
name: "Errors when the updater fails",
args: params{
pt: &datatypes.PreservationTask{
Name: "PT 1",
TaskID: taskID.String(),
Status: enums.PreservationTaskStatusInProgress,
},
updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) {
return nil, fmt.Errorf("Bad input")
},
},
wantErr: "invalid data error: updater error: Bad input",
},
{
name: "Errors on an invalid TaskID",
args: params{
pt: &datatypes.PreservationTask{
Name: "PT 1",
TaskID: taskID.String(),
Status: enums.PreservationTaskStatusInProgress,
},
updater: func(p *datatypes.PreservationTask) (*datatypes.PreservationTask, error) {
p.TaskID = "123456"
return p, nil
},
},
wantErr: "invalid data error: parse error: field \"TaskID\": invalid UUID length: 6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()
entc, svc := setUpClient(t, logr.Discard())
pa, pa2 := addDBFixtures(t, entc)

updater := tt.args.updater
var id uint
if tt.args.pt != nil {
pt := *tt.args.pt // Make a local copy of pt.
pt.PreservationActionID = uint(pa.ID)

// Create preservation task to be updated.
err := svc.CreatePreservationTask(ctx, &pt)
if err != nil {
t.Errorf("create preservation task: %v", err)
}
id = pt.ID

// Update PreservationActionID to pa2.ID.
updater = func(pt *datatypes.PreservationTask) (*datatypes.PreservationTask, error) {
pt, err := tt.args.updater(pt)
if err != nil {
return nil, err
}
pt.PreservationActionID = uint(pa2.ID)

return pt, nil
}
}

pp, err := svc.UpdatePreservationTask(ctx, id, updater)
if tt.wantErr != "" {
assert.Error(t, err, tt.wantErr)
return
}

assert.Equal(t, pp.ID, id)
assert.Equal(t, pp.TaskID, tt.want.TaskID)
assert.Equal(t, pp.Name, tt.want.Name)
assert.Equal(t, pp.Status, tt.want.Status)
assert.Equal(t, pp.StartedAt, tt.want.StartedAt)
assert.Equal(t, pp.CompletedAt, tt.want.CompletedAt)
assert.Equal(t, pp.Note, tt.want.Note)
assert.Equal(t, pp.PreservationActionID, uint(pa2.ID))
})
}
}
39 changes: 39 additions & 0 deletions internal/persistence/fake/mock_persistence.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion internal/persistence/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ var (
)

type (
PackageUpdater func(*datatypes.Package) (*datatypes.Package, error)
PackageUpdater func(*datatypes.Package) (*datatypes.Package, error)
PresTaskUpdater func(*datatypes.PreservationTask) (*datatypes.PreservationTask, error)
)

type Service interface {
Expand All @@ -30,4 +31,5 @@ type Service interface {
UpdatePackage(context.Context, uint, PackageUpdater) (*datatypes.Package, error)

CreatePreservationTask(context.Context, *datatypes.PreservationTask) error
UpdatePreservationTask(ctx context.Context, id uint, updater PresTaskUpdater) (*datatypes.PreservationTask, error)
}
18 changes: 18 additions & 0 deletions internal/persistence/telemetry.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,21 @@ func (w *wrapper) CreatePreservationTask(ctx context.Context, pt *datatypes.Pres

return nil
}

func (w *wrapper) UpdatePreservationTask(
ctx context.Context,
id uint,
updater PresTaskUpdater,
) (*datatypes.PreservationTask, error) {
ctx, span := w.tracer.Start(ctx, "UpdatePreservationTask")
defer span.End()
span.SetAttributes(attribute.Int("id", int(id)))

Check warning on line 81 in internal/persistence/telemetry.go

View check run for this annotation

Codecov / codecov/patch

internal/persistence/telemetry.go#L78-L81

Added lines #L78 - L81 were not covered by tests

r, err := w.wrapped.UpdatePreservationTask(ctx, id, updater)
if err != nil {
telemetry.RecordError(span, err)
return nil, updateError(err, "UpdatePreservationTask")

Check warning on line 86 in internal/persistence/telemetry.go

View check run for this annotation

Codecov / codecov/patch

internal/persistence/telemetry.go#L83-L86

Added lines #L83 - L86 were not covered by tests
}

return r, nil

Check warning on line 89 in internal/persistence/telemetry.go

View check run for this annotation

Codecov / codecov/patch

internal/persistence/telemetry.go#L89

Added line #L89 was not covered by tests
}

0 comments on commit dd004da

Please sign in to comment.