diff --git a/go.mod b/go.mod index 9168d2bf..43459a5b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/SafetyCulture/safetyculture-exporter go 1.18 require ( + github.com/MickStanciu/go-fn v1.3.0 github.com/dghubble/sling v1.4.1 github.com/gocarina/gocsv v0.0.0-20220310154401-d4df709ca055 github.com/gofrs/uuid v4.4.0+incompatible diff --git a/go.sum b/go.sum index aa8e217d..c5f073db 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0/go.mod h1:eWRD7oawr1Mu1sLC github.com/AzureAD/microsoft-authentication-library-for-go v0.8.1/go.mod h1:4qFor3D/HDsvBME35Xy9rwW9DecL+M2sNw1ybjPtwA0= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/MickStanciu/go-fn v1.3.0 h1:GjdM1xEHtjHJj2fS8VAtQj33T0j2uQAPf2et80QqP1s= +github.com/MickStanciu/go-fn v1.3.0/go.mod h1:VXISTRTin8MigXG12B3/9IOhVadpp07v/vXKpUIQr6Q= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= diff --git a/pkg/internal/feed/feed_action_assignees.go b/pkg/internal/feed/feed_action_assignees.go index a9706569..51428d19 100644 --- a/pkg/internal/feed/feed_action_assignees.go +++ b/pkg/internal/feed/feed_action_assignees.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "github.com/MickStanciu/go-fn/fn" + "github.com/SafetyCulture/safetyculture-exporter/pkg/internal/util" "github.com/SafetyCulture/safetyculture-exporter/pkg/logger" "github.com/SafetyCulture/safetyculture-exporter/pkg/httpapi" @@ -76,27 +78,25 @@ func (f *ActionAssigneeFeed) CreateSchema(exporter Exporter) error { func (f *ActionAssigneeFeed) writeRows(exporter Exporter, rows []*ActionAssignee) error { 
// Calculate the size of the batch we can insert into the DB at once. Column count + buffer to account for primary keys batchSize := exporter.ParameterLimit() / (len(f.Columns()) + 4) + err := util.SplitSliceInBatch(batchSize, rows, func(batch []*ActionAssignee) error { + // Delete the actions if already exists + actionIDs := fn.Map(batch, func(row *ActionAssignee) string { + return row.ActionID + }) - for i := 0; i < len(rows); i += batchSize { - j := i + batchSize - if j > len(rows) { - j = len(rows) - } - var actionIDs []string - for k := range rows[i:j] { - actionIDs = append(actionIDs, rows[k].ActionID) - } - - // Delete the actions if already exist if err := exporter.DeleteRowsIfExist(f, "action_id IN ?", actionIDs); err != nil { return fmt.Errorf("delete rows: %w", err) } - if err := exporter.WriteRows(f, rows[i:j]); err != nil { + if err := exporter.WriteRows(f, batch); err != nil { return events.WrapEventError(err, "write rows") } - } + return nil + }) + if err != nil { + return err + } return nil } diff --git a/pkg/internal/util/slice.go b/pkg/internal/util/slice.go new file mode 100644 index 00000000..616761e2 --- /dev/null +++ b/pkg/internal/util/slice.go @@ -0,0 +1,24 @@ +package util + +import ( + "fmt" +) + +// SplitSliceInBatch splits a slice into batches and calls the callback function for each batch. 
// SplitSliceInBatch walks collection in consecutive batches of at most size
// elements and calls fn once per batch, in order.
//
// Every batch is a sub-slice of collection (no copy is made), so it shares the
// caller's backing array. The last batch may be shorter than size. Iteration
// stops at the first error returned by fn, and that error is returned as is.
// An empty collection results in no calls to fn.
//
// A size of zero or less is rejected with an error before fn is ever called.
func SplitSliceInBatch[T any](size int, collection []T, fn func(batch []T) error) error {
	switch {
	case size == 0:
		return fmt.Errorf("batch size cannot be 0")
	case size < 0:
		// A negative size would produce a negative upper bound and panic when slicing.
		return fmt.Errorf("batch size cannot be negative: %d", size)
	}

	for start := 0; start < len(collection); start += size {
		end := start + size
		if end > len(collection) {
			// Final, partially filled batch.
			end = len(collection)
		}

		if err := fn(collection[start:end]); err != nil {
			return err
		}
	}
	return nil
}
err := util.SplitSliceInBatch(tt.size, tt.collection, tt.fn) + if tt.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tt.expectedErr, err) + } + }) + } +}