Skip to content

Commit 6230f67

Browse files
committed
address comments on Get behavior
1 parent 740357b commit 6230f67

File tree

2 files changed

+147
-39
lines changed

2 files changed

+147
-39
lines changed

internal/batch/batch_future.go

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (b *batchFutureImpl) start(ctx internal.Context) {
7575
idx := idx
7676

7777
wgForFutures.Add(1)
78-
internal.GoNamed(ctx, "batch-future-processor-one-future", func(ctx internal.Context) {
78+
internal.GoNamed(ctx, fmt.Sprintf("batch-future-processor-one-future-%d", idx), func(ctx internal.Context) {
7979
defer wgForFutures.Done()
8080

8181
// fork a future and chain it to the processed future for user to get the result
@@ -100,49 +100,42 @@ func (b *batchFutureImpl) IsReady() bool {
100100
return true
101101
}
102102

103+
// Get assigns the result of the futures to the valuePtr.
104+
// NOTE: valuePtr must be a pointer to a slice, or nil.
105+
// If valuePtr is a pointer to a slice, the slice will be resized to the length of the futures. Each element of the slice will be assigned with the underlying Future.Get() and thus behaves the same way.
106+
// If valuePtr is nil, no assignment will be made.
107+
// If error occurs, values will be set on successful futures and the errors of failed futures will be returned.
103108
func (b *batchFutureImpl) Get(ctx internal.Context, valuePtr interface{}) error {
104-
// ensure valuePtr is a slice
105-
var sliceValue reflect.Value
106-
if valuePtr != nil {
107-
108-
switch v := reflect.ValueOf(valuePtr); v.Kind() {
109-
case reflect.Ptr:
110-
if v.Elem().Kind() != reflect.Slice {
111-
return fmt.Errorf("valuePtr must be a pointer to a slice, got %v", v)
112-
}
113-
sliceValue = v.Elem()
114-
case reflect.Slice:
115-
sliceValue = v
116-
default:
117-
return fmt.Errorf("valuePtr must be a slice or a pointer to a slice, got %v", v.Kind())
118-
}
119-
// ensure slice size is the same as the number of futures
120-
if sliceValue.Len() != len(b.futures) {
121-
return fmt.Errorf("slice size must be the same as the number of futures, got %d, expected %d", sliceValue.Len(), len(b.futures))
109+
// No assignment if valuePtr is nil
110+
if valuePtr == nil {
111+
b.wg.Wait(ctx)
112+
var errs error
113+
for i := range b.futures {
114+
errs = multierr.Append(errs, b.futures[i].Get(ctx, nil))
122115
}
116+
return errs
117+
}
118+
119+
v := reflect.ValueOf(valuePtr)
120+
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice {
121+
return fmt.Errorf("valuePtr must be a pointer to a slice, got %v", v.Kind())
123122
}
124123

124+
// resize the slice to the length of the futures
125+
slice := v.Elem()
126+
if slice.Cap() < len(b.futures) {
127+
slice.Grow(len(b.futures) - slice.Cap())
128+
}
129+
slice.SetLen(len(b.futures))
130+
125131
// wait for all futures to be ready
126132
b.wg.Wait(ctx)
127133

128134
// loop through all elements of valuePtr
129135
var errs error
130136
for i := range b.futures {
131-
if valuePtr == nil {
132-
errs = multierr.Append(errs, b.futures[i].Get(ctx, nil))
133-
} else {
134-
value := sliceValue.Index(i)
135-
if value.Kind() != reflect.Ptr {
136-
value = value.Addr()
137-
}
138-
// if value is nil, initialize it
139-
if value.IsNil() {
140-
value.Set(reflect.New(value.Type().Elem()))
141-
}
142-
143-
e := b.futures[i].Get(ctx, value.Interface())
144-
errs = multierr.Append(errs, e)
145-
}
137+
e := b.futures[i].Get(ctx, slice.Index(i).Addr().Interface())
138+
errs = multierr.Append(errs, e)
146139
}
147140

148141
return errs

internal/batch/batch_future_test.go

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"math/rand"
8-
"reflect"
98
"sync"
109
"testing"
1110
"time"
@@ -211,9 +210,125 @@ func Test_Futures(t *testing.T) {
211210
env.ExecuteWorkflow(futureTest)
212211
}
213212

214-
func Test_valuePtr(t *testing.T) {
215-
slices := make([]int, 10)
216-
slicePtr := &slices
213+
func batchWorkflowAssignWithSlice(ctx internal.Context) ([]int, error) {
214+
totalSize := 5
215+
concurrency := 2
216+
factories := make([]func(ctx internal.Context) internal.Future, totalSize)
217+
for i := 0; i < totalSize; i++ {
218+
i := i
219+
factories[i] = func(ctx internal.Context) internal.Future {
220+
aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
221+
ScheduleToStartTimeout: time.Second * 10,
222+
StartToCloseTimeout: time.Second * 10,
223+
})
224+
return internal.ExecuteActivity(aCtx, batchActivity, i)
225+
}
226+
}
227+
228+
batchFuture, err := NewBatchFuture(ctx, concurrency, factories)
229+
if err != nil {
230+
return nil, err
231+
}
217232

218-
fmt.Println(reflect.ValueOf(slicePtr).Elem().Len())
233+
var valuePtr []int
234+
if err := batchFuture.Get(ctx, &valuePtr); err != nil {
235+
return nil, err
236+
}
237+
return valuePtr, nil
238+
}
239+
240+
func batchWorkflowAssignWithSliceOfPointers(ctx internal.Context) ([]int, error) {
241+
totalSize := 5
242+
concurrency := 2
243+
factories := make([]func(ctx internal.Context) internal.Future, totalSize)
244+
for i := 0; i < totalSize; i++ {
245+
i := i
246+
factories[i] = func(ctx internal.Context) internal.Future {
247+
aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
248+
ScheduleToStartTimeout: time.Second * 10,
249+
StartToCloseTimeout: time.Second * 10,
250+
})
251+
return internal.ExecuteActivity(aCtx, batchActivity, i)
252+
}
253+
}
254+
batchFuture, err := NewBatchFuture(ctx, concurrency, factories)
255+
if err != nil {
256+
return nil, err
257+
}
258+
var valuePtr []*int
259+
if err := batchFuture.Get(ctx, &valuePtr); err != nil {
260+
return nil, err
261+
}
262+
263+
var result []int
264+
for _, v := range valuePtr {
265+
result = append(result, *v)
266+
}
267+
return result, nil
268+
}
269+
270+
func batchWorkflowAssignWithNil(ctx internal.Context) ([]int, error) {
271+
totalSize := 5
272+
concurrency := 2
273+
factories := make([]func(ctx internal.Context) internal.Future, totalSize)
274+
for i := 0; i < totalSize; i++ {
275+
i := i
276+
factories[i] = func(ctx internal.Context) internal.Future {
277+
aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
278+
ScheduleToStartTimeout: time.Second * 10,
279+
StartToCloseTimeout: time.Second * 10,
280+
})
281+
return internal.ExecuteActivity(aCtx, batchActivity, i)
282+
}
283+
}
284+
285+
batchFuture, err := NewBatchFuture(ctx, concurrency, factories)
286+
if err != nil {
287+
return nil, err
288+
}
289+
290+
var valuePtr []int
291+
if err := batchFuture.Get(ctx, nil); err != nil {
292+
return nil, err
293+
}
294+
return valuePtr, nil
295+
}
296+
297+
func Test_BatchFuture_Get(t *testing.T) {
298+
tests := []struct {
299+
name string
300+
workflow func(ctx internal.Context) ([]int, error)
301+
want interface{}
302+
}{
303+
{
304+
name: "success with nil slice",
305+
workflow: batchWorkflowAssignWithSlice,
306+
want: []int{0, 1, 2, 3, 4},
307+
},
308+
{
309+
name: "success with non-nil slice",
310+
workflow: batchWorkflowAssignWithSliceOfPointers,
311+
want: []int{0, 1, 2, 3, 4},
312+
},
313+
{
314+
name: "success with nil",
315+
workflow: batchWorkflowAssignWithNil,
316+
want: []int(nil),
317+
},
318+
}
319+
320+
for _, tt := range tests {
321+
t.Run(tt.name, func(t *testing.T) {
322+
testSuite := &testsuite.WorkflowTestSuite{}
323+
env := testSuite.NewTestWorkflowEnvironment()
324+
env.RegisterWorkflow(tt.workflow)
325+
env.RegisterActivity(batchActivity)
326+
env.ExecuteWorkflow(tt.workflow)
327+
assert.True(t, env.IsWorkflowCompleted())
328+
assert.Nil(t, env.GetWorkflowError())
329+
var result []int
330+
assert.Nil(t, env.GetWorkflowResult(&result))
331+
assert.Equal(t, tt.want, result)
332+
})
333+
}
219334
}

0 commit comments

Comments
 (0)