diff --git a/backend/activity.go b/backend/activity.go index 0f2423f..970169c 100644 --- a/backend/activity.go +++ b/backend/activity.go @@ -21,12 +21,12 @@ type ActivityExecutor interface { ExecuteActivity(context.Context, api.InstanceID, *protos.HistoryEvent) (*protos.HistoryEvent, error) } -func NewActivityTaskWorker(be Backend, executor ActivityExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { +func NewActivityTaskWorker(be Backend, executor ActivityExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker[*ActivityWorkItem] { processor := newActivityProcessor(be, executor) return NewTaskWorker(processor, logger, opts...) } -func newActivityProcessor(be Backend, executor ActivityExecutor) TaskProcessor { +func newActivityProcessor(be Backend, executor ActivityExecutor) TaskProcessor[*ActivityWorkItem] { return &activityProcessor{ be: be, executor: executor, @@ -38,15 +38,13 @@ func (*activityProcessor) Name() string { return "activity-processor" } -// FetchWorkItem implements TaskDispatcher -func (ap *activityProcessor) FetchWorkItem(ctx context.Context) (WorkItem, error) { - return ap.be.GetActivityWorkItem(ctx) +// NextWorkItem implements TaskDispatcher +func (ap *activityProcessor) NextWorkItem(ctx context.Context) (*ActivityWorkItem, error) { + return ap.be.NextActivityWorkItem(ctx) } // ProcessWorkItem implements TaskDispatcher -func (p *activityProcessor) ProcessWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) - +func (p *activityProcessor) ProcessWorkItem(ctx context.Context, awi *ActivityWorkItem) error { ts := awi.NewEvent.GetTaskScheduled() if ts == nil { return fmt.Errorf("%v: invalid TaskScheduled event", awi.InstanceID) @@ -83,20 +81,18 @@ func (p *activityProcessor) ProcessWorkItem(ctx context.Context, wi WorkItem) er } // CompleteWorkItem implements TaskDispatcher -func (ap *activityProcessor) CompleteWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) +func (ap *activityProcessor) CompleteWorkItem(ctx context.Context, awi *ActivityWorkItem) error { if awi.Result == nil { - return fmt.Errorf("can't complete work item '%s' with nil result", wi) + return fmt.Errorf("can't complete work item '%s' with nil result", awi) } if awi.Result.GetTaskCompleted() == nil && awi.Result.GetTaskFailed() == nil { - return fmt.Errorf("can't complete work item '%s', which isn't TaskCompleted or TaskFailed", wi) + return fmt.Errorf("can't complete work item '%s', which isn't TaskCompleted or TaskFailed", awi) } return ap.be.CompleteActivityWorkItem(ctx, awi) } // AbandonWorkItem implements TaskDispatcher -func (ap *activityProcessor) AbandonWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) +func (ap *activityProcessor) AbandonWorkItem(ctx context.Context, awi *ActivityWorkItem) error { return ap.be.AbandonActivityWorkItem(ctx, awi) } diff --git a/backend/backend.go b/backend/backend.go index 61692a9..cfaebb3 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -68,9 +68,9 @@ type Backend interface { // AddNewEvent adds a new orchestration event to the specified orchestration instance. AddNewOrchestrationEvent(context.Context, api.InstanceID, *HistoryEvent) error - // GetOrchestrationWorkItem gets a pending work item from the task hub or returns [ErrNoOrchWorkItems] - // if there are no pending work items. - GetOrchestrationWorkItem(context.Context) (*OrchestrationWorkItem, error) + // NextOrchestrationWorkItem blocks and returns the next orchestration work + // item from the task hub. Should only return an error when shutting down. + NextOrchestrationWorkItem(context.Context) (*OrchestrationWorkItem, error) // GetOrchestrationRuntimeState gets the runtime state of an orchestration instance. GetOrchestrationRuntimeState(context.Context, *OrchestrationWorkItem) (*OrchestrationRuntimeState, error) @@ -97,9 +97,9 @@ type Backend interface { // completes with a failure is still considered a successfully processed work item). AbandonOrchestrationWorkItem(context.Context, *OrchestrationWorkItem) error - // GetActivityWorkItem gets a pending activity work item from the task hub or returns [ErrNoWorkItems] - // if there are no pending activity work items. - GetActivityWorkItem(context.Context) (*ActivityWorkItem, error) + // NextActivityWorkItem blocks and returns the next activity work item from + // the task hub. Should only return an error when shutting down. + NextActivityWorkItem(context.Context) (*ActivityWorkItem, error) // CompleteActivityWorkItem sends a message to the parent orchestration indicating activity completion. // diff --git a/backend/orchestration.go b/backend/orchestration.go index b37653f..a1a3c2d 100644 --- a/backend/orchestration.go +++ b/backend/orchestration.go @@ -30,13 +30,13 @@ type orchestratorProcessor struct { logger Logger } -func NewOrchestrationWorker(be Backend, executor OrchestratorExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { +func NewOrchestrationWorker(be Backend, executor OrchestratorExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker[*OrchestrationWorkItem] { processor := &orchestratorProcessor{ be: be, executor: executor, logger: logger, } - return NewTaskWorker(processor, logger, opts...) + return NewTaskWorker[*OrchestrationWorkItem](processor, logger, opts...) } // Name implements TaskProcessor @@ -44,14 +44,13 @@ func (*orchestratorProcessor) Name() string { return "orchestration-processor" } -// FetchWorkItem implements TaskProcessor -func (p *orchestratorProcessor) FetchWorkItem(ctx context.Context) (WorkItem, error) { - return p.be.GetOrchestrationWorkItem(ctx) +// NextWorkItem implements TaskProcessor +func (p *orchestratorProcessor) NextWorkItem(ctx context.Context) (*OrchestrationWorkItem, error) { + return p.be.NextOrchestrationWorkItem(ctx) } // ProcessWorkItem implements TaskProcessor -func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, cwi WorkItem) error { - wi := cwi.(*OrchestrationWorkItem) +func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { w.logger.Debugf("%v: received work item with %d new event(s): %v", wi.InstanceID, len(wi.NewEvents), helpers.HistoryListSummary(wi.NewEvents)) // TODO: Caching @@ -131,15 +130,13 @@ func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, cwi WorkIte } // CompleteWorkItem implements TaskProcessor -func (p *orchestratorProcessor) CompleteWorkItem(ctx context.Context, wi WorkItem) error { - owi := wi.(*OrchestrationWorkItem) - return p.be.CompleteOrchestrationWorkItem(ctx, owi) +func (p *orchestratorProcessor) CompleteWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { + return p.be.CompleteOrchestrationWorkItem(ctx, wi) } // AbandonWorkItem implements TaskProcessor -func (p *orchestratorProcessor) AbandonWorkItem(ctx context.Context, wi WorkItem) error { - owi := wi.(*OrchestrationWorkItem) - return p.be.AbandonOrchestrationWorkItem(ctx, owi) +func (p *orchestratorProcessor) AbandonWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { + return p.be.AbandonOrchestrationWorkItem(ctx, wi) } func (w *orchestratorProcessor) applyWorkItem(ctx context.Context, wi *OrchestrationWorkItem) (context.Context, trace.Span, bool) { diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 962c5b6..8701650 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -28,6 +28,8 @@ var schema string var emptyString string = "" +var errNoWorkItems = errors.New("no work items were found") + type SqliteOptions struct { OrchestrationLockTimeout time.Duration ActivityLockTimeout time.Duration @@ -40,6 +42,9 @@ type sqliteBackend struct { workerName string logger backend.Logger options *SqliteOptions + + activityWorker *backend.TaskWorker[*backend.ActivityWorkItem] + orchestrationWorker *backend.TaskWorker[*backend.OrchestrationWorkItem] } // NewSqliteOptions creates a new options object for the sqlite backend provider. @@ -778,8 +783,8 @@ func (be *sqliteBackend) GetOrchestrationRuntimeState(ctx context.Context, wi *b return state, nil } -// GetOrchestrationWorkItem implements backend.Backend -func (be *sqliteBackend) GetOrchestrationWorkItem(ctx context.Context) (*backend.OrchestrationWorkItem, error) { +// getOrchestrationWorkItem implements backend.Backend +func (be *sqliteBackend) getOrchestrationWorkItem(ctx context.Context) (*backend.OrchestrationWorkItem, error) { if err := be.ensureDB(); err != nil { return nil, err } @@ -819,7 +824,7 @@ func (be *sqliteBackend) GetOrchestrationWorkItem(ctx context.Context) (*backend if err := row.Scan(&instanceID); err != nil { if err == sql.ErrNoRows { // No new events to process - return nil, backend.ErrNoWorkItems + return nil, errNoWorkItems } return nil, fmt.Errorf("failed to scan the orchestration work-item: %w", err) @@ -878,7 +883,73 @@ func (be *sqliteBackend) GetOrchestrationWorkItem(ctx context.Context) (*backend return wi, nil } -func (be *sqliteBackend) GetActivityWorkItem(ctx context.Context) (*backend.ActivityWorkItem, error) { +func (be *sqliteBackend) NextOrchestrationWorkItem(ctx context.Context) (*backend.OrchestrationWorkItem, error) { + b := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 50 * time.Millisecond, + MaxInterval: 5 * time.Second, + Multiplier: 1.05, + RandomizationFactor: 0.05, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) + + for { + wi, err := be.getOrchestrationWorkItem(ctx) + if err == nil { + return wi, nil + } + + if !errors.Is(err, errNoWorkItems) { + return nil, err + } + + t := time.NewTimer(b.NextBackOff()) + select { + case <-t.C: + case <-ctx.Done(): + if !t.Stop() { + <-t.C + } + be.logger.Info("Activity: received cancellation signal") + return nil, ctx.Err() + } + } +} + +func (be *sqliteBackend) NextActivityWorkItem(ctx context.Context) (*backend.ActivityWorkItem, error) { + b := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 50 * time.Millisecond, + MaxInterval: 5 * time.Second, + Multiplier: 1.05, + RandomizationFactor: 0.05, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) + + for { + wi, err := be.getActivityWorkItem(ctx) + if err == nil { + return wi, nil + } + + if !errors.Is(err, errNoWorkItems) { + return nil, err + } + + t := time.NewTimer(b.NextBackOff()) + select { + case <-t.C: + case <-ctx.Done(): + if !t.Stop() { + <-t.C + } + be.logger.Info("Activity: received cancellation signal") + return nil, ctx.Err() + } + } +} + +func (be *sqliteBackend) getActivityWorkItem(ctx context.Context) (*backend.ActivityWorkItem, error) { if err := be.ensureDB(); err != nil { return nil, err } @@ -910,7 +981,7 @@ func (be *sqliteBackend) GetActivityWorkItem(ctx context.Context) (*backend.Acti if err := row.Scan(&sequenceNumber, &instanceID, &eventPayload); err != nil { if err == sql.ErrNoRows { // No new activity tasks to process - return nil, backend.ErrNoWorkItems + return nil, errNoWorkItems } return nil, fmt.Errorf("failed to scan the activity work-item: %w", err) diff --git a/backend/taskhub.go b/backend/taskhub.go index 688cb73..2fa92cc 100644 --- a/backend/taskhub.go +++ b/backend/taskhub.go @@ -15,12 +15,12 @@ type TaskHubWorker interface { type taskHubWorker struct { backend Backend - orchestrationWorker TaskWorker - activityWorker TaskWorker + orchestrationWorker TaskWorker[*OrchestrationWorkItem] + activityWorker TaskWorker[*ActivityWorkItem] logger Logger } -func NewTaskHubWorker(be Backend, orchestrationWorker TaskWorker, activityWorker TaskWorker, logger Logger) TaskHubWorker { +func NewTaskHubWorker(be Backend, orchestrationWorker TaskWorker[*OrchestrationWorkItem], activityWorker TaskWorker[*ActivityWorkItem], logger Logger) TaskHubWorker { return &taskHubWorker{ backend: be, orchestrationWorker: orchestrationWorker, @@ -30,13 +30,14 @@ func NewTaskHubWorker(be Backend, orchestrationWorker TaskWorker, activityWorker } func (w *taskHubWorker) Start(ctx context.Context) error { - // TODO: Check for already started worker if err := w.backend.CreateTaskHub(ctx); err != nil && err != ErrTaskHubExists { return err } + if err := w.backend.Start(ctx); err != nil { return err } + w.logger.Infof("worker started with backend %v", w.backend) w.orchestrationWorker.Start(ctx) @@ -53,7 +54,7 @@ func (w *taskHubWorker) Shutdown(ctx context.Context) error { w.logger.Info("workers stopping and draining...") defer w.logger.Info("finished stopping and draining workers!") - wg := sync.WaitGroup{} + var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() diff --git a/backend/worker.go b/backend/worker.go index 640e7bc..e4f38a2 100644 --- a/backend/worker.go +++ b/backend/worker.go @@ -2,51 +2,33 @@ package backend import ( "context" - "errors" "sync" - "sync/atomic" - "time" - - "github.com/cenkalti/backoff/v4" - "github.com/marusama/semaphore/v2" ) -type TaskWorker interface { +type TaskWorker[T WorkItem] interface { // Start starts background polling for the activity work items. Start(context.Context) - // ProcessNext attempts to fetch and process a work item. This method returns - // true if a work item was found and processing started; false otherwise. An - // error is returned if the context is cancelled. - ProcessNext(context.Context) (bool, error) - // StopAndDrain stops the worker and waits for all outstanding work items to finish. StopAndDrain() } -type TaskProcessor interface { +type TaskProcessor[T WorkItem] interface { Name() string - FetchWorkItem(context.Context) (WorkItem, error) - ProcessWorkItem(context.Context, WorkItem) error - AbandonWorkItem(context.Context, WorkItem) error - CompleteWorkItem(context.Context, WorkItem) error + ProcessWorkItem(context.Context, T) error + NextWorkItem(context.Context) (T, error) + AbandonWorkItem(context.Context, T) error + CompleteWorkItem(context.Context, T) error } -type worker struct { - options *WorkerOptions - logger Logger - // dispatchSemaphore is for throttling orchestration concurrency. - dispatchSemaphore semaphore.Semaphore - - // pending is for keeping track of outstanding orchestration executions. - pending *sync.WaitGroup - - // cancel is used to cancel background polling. - // It will be nil if background polling isn't started. - cancel context.CancelFunc - processor TaskProcessor - waiting bool - stop atomic.Bool +type worker[T WorkItem] struct { + logger Logger + + processor TaskProcessor[T] + closeCh chan struct{} + wg sync.WaitGroup + workItems chan T + parallelLock chan struct{} } type NewTaskWorkerOptions func(*WorkerOptions) @@ -67,188 +49,96 @@ func WithMaxParallelism(n int32) NewTaskWorkerOptions { } } -func NewTaskWorker(p TaskProcessor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { +func NewTaskWorker[T WorkItem](p TaskProcessor[T], logger Logger, opts ...NewTaskWorkerOptions) TaskWorker[T] { options := &WorkerOptions{MaxParallelWorkItems: 1} for _, configure := range opts { configure(options) } - return &worker{ - processor: p, - logger: logger, - dispatchSemaphore: semaphore.New(int(options.MaxParallelWorkItems)), - pending: &sync.WaitGroup{}, - cancel: nil, // assigned later - options: options, + return &worker[T]{ + processor: p, + logger: logger, + workItems: make(chan T), + parallelLock: make(chan struct{}, options.MaxParallelWorkItems), + closeCh: make(chan struct{}), } } -func (w *worker) Name() string { +func (w *worker[T]) Name() string { return w.processor.Name() } -func (w *worker) Start(ctx context.Context) { - // TODO: Check for already started worker - ctx, cancel := context.WithCancel(ctx) - w.cancel = cancel +func (w *worker[T]) Start(ctx context.Context) { + w.wg.Add(2) - w.stop.Store(false) + ctx, cancel := context.WithCancel(ctx) go func() { - var b backoff.BackOff = &backoff.ExponentialBackOff{ - InitialInterval: 50 * time.Millisecond, - MaxInterval: 5 * time.Second, - Multiplier: 1.05, - RandomizationFactor: 0.05, - Stop: backoff.Stop, - Clock: backoff.SystemClock, + defer w.wg.Done() + defer cancel() + + select { + case <-w.closeCh: + case <-ctx.Done(): } - b = backoff.WithContext(b, ctx) - b.Reset() + }() + + go func() { + defer w.wg.Done() + defer w.logger.Infof("%v: worker stopped", w.Name()) - loop: for { - // returns right away, with "ok" if a work item was found - ok, err := w.ProcessNext(ctx) - - switch { - case ok: - // found a work item - reset the backoff and check for the next item - b.Reset() - case err != nil && errors.Is(err, ctx.Err()): - // there's an error and it's due to the context being canceled - w.logger.Infof("%v: received cancellation signal", w.Name()) - break loop - case err != nil: - // another error was encountered - // log the error and inject some extra sleep to avoid tight failure loops - w.logger.Errorf("unexpected worker error: %v. Adding 5 extra seconds of backoff.", err) - t := time.NewTimer(5 * time.Second) - select { - case <-t.C: - // nop - all good - case <-ctx.Done(): - if !t.Stop() { - <-t.C - } - w.logger.Infof("%v: received cancellation signal", w.Name()) - break loop - } - default: - // no work item found, so sleep until the next backoff - t := time.NewTimer(b.NextBackOff()) - select { - case <-t.C: - // nop - all good - case <-ctx.Done(): - if !t.Stop() { - <-t.C - } - w.logger.Infof("%v: received cancellation signal", w.Name()) - break loop - } + select { + case w.parallelLock <- struct{}{}: + case <-ctx.Done(): + return } - } - w.logger.Infof("%v: stopped listening for new work items", w.Name()) - }() -} + wi, err := w.processor.NextWorkItem(ctx) + if err != nil { + <-w.parallelLock -func (w *worker) ProcessNext(ctx context.Context) (bool, error) { - if !w.dispatchSemaphore.TryAcquire(1) { - w.logger.Debugf("%v: waiting for one of %v in-flight execution(s) to complete", w.Name(), w.dispatchSemaphore.GetCount()) - if err := w.dispatchSemaphore.Acquire(ctx, 1); err != nil { - // cancelled - return false, err - } - } - w.pending.Add(1) + if ctx.Err() != nil { + return + } - processing := false - defer func() { - if !processing { - w.pending.Done() - w.dispatchSemaphore.Release(1) - } - }() + w.logger.Errorf("%v: failed to get next work item: %v", w.Name(), err) + continue + } - wi, err := w.processor.FetchWorkItem(ctx) - switch { - case errors.Is(err, ErrNoWorkItems) || wi == nil: - if !w.waiting { - w.logger.Debugf("%v: waiting for new work items...", w.Name()) - w.waiting = true - } - return false, nil - case err != nil: - if !errors.Is(err, ctx.Err()) { - w.logger.Errorf("%v: failed to fetch work item: %v", w.Name(), err) + w.wg.Add(1) + go func() { + defer func() { + <-w.parallelLock + w.wg.Done() + }() + w.processWorkItem(ctx, wi) + }() } - return false, err - default: - // process the work-item in the background - w.waiting = false - processing = true - go w.processWorkItem(ctx, wi) - return true, nil - } + }() } -func (w *worker) StopAndDrain() { - w.logger.Debugf("%v: stop and drain...", w.Name()) - defer w.logger.Debugf("%v: finished stop and drain...", w.Name()) - - w.stop.Store(true) - - // Cancel the background poller and dispatcher(s) - if w.cancel != nil { - w.cancel() - } - - // Wait for outstanding work-items to finish processing. - // TODO: Need to find a way to cancel this if it takes too long for some reason. - w.pending.Wait() +func (w *worker[T]) StopAndDrain() { + close(w.closeCh) + w.wg.Wait() } -func (w *worker) processWorkItem(ctx context.Context, wi WorkItem) { - defer w.dispatchSemaphore.Release(1) - defer w.pending.Done() - +func (w *worker[T]) processWorkItem(ctx context.Context, wi T) { w.logger.Debugf("%v: processing work item: %s", w.Name(), wi) - if w.stop.Load() { - if err := w.processor.AbandonWorkItem(context.Background(), wi); err != nil { - w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) - } - return - } - if err := w.processor.ProcessWorkItem(ctx, wi); err != nil { - if errors.Is(err, ctx.Err()) { - w.logger.Warnf("%v: abandoning work item due to cancellation", w.Name()) - } else { - w.logger.Errorf("%v: failed to process work item: %v", w.Name(), err) - } - if w.stop.Load() { - ctx = context.Background() - } - if err := w.processor.AbandonWorkItem(ctx, wi); err != nil { + w.logger.Errorf("%v: failed to process work item: %v", w.Name(), err) + if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) } return } if err := w.processor.CompleteWorkItem(ctx, wi); err != nil { - if errors.Is(err, ctx.Err()) { - w.logger.Warnf("%v: failed to complete work item due to cancellation", w.Name()) - } else { - w.logger.Errorf("%v: failed to complete work item: %v", w.Name(), err) - } - if w.stop.Load() { - ctx = context.Background() - } - if err := w.processor.AbandonWorkItem(ctx, wi); err != nil { + w.logger.Errorf("%v: failed to complete work item: %v", w.Name(), err) + if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) } + return } w.logger.Debugf("%v: work item processed successfully", w.Name()) diff --git a/backend/workitem.go b/backend/workitem.go index 26ab2d7..1a938e6 100644 --- a/backend/workitem.go +++ b/backend/workitem.go @@ -1,15 +1,12 @@ package backend import ( - "errors" "fmt" "time" "github.com/dapr/durabletask-go/api" ) -var ErrNoWorkItems = errors.New("no work items were found") - type WorkItem interface { fmt.Stringer IsWorkItem() bool diff --git a/go.mod b/go.mod index 3d3b832..4ece2fe 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/cenkalti/backoff/v4 v4.2.1 github.com/dapr/kit v0.13.1-0.20250110192255-fb195706966f github.com/google/uuid v1.6.0 - github.com/marusama/semaphore/v2 v2.5.0 github.com/stretchr/testify v1.9.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 go.opentelemetry.io/otel v1.18.0 diff --git a/go.sum b/go.sum index b9eb71e..7f05915 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM= -github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= diff --git a/submodules/durabletask-protobuf b/submodules/durabletask-protobuf index 4207e1d..c62e0a2 160000 --- a/submodules/durabletask-protobuf +++ b/submodules/durabletask-protobuf @@ -1 +1 @@ -Subproject commit 4207e1dbd14cedc268f69c3befee60fcaad19367 +Subproject commit c62e0a234019a300f334ceedbd96831dbca1bca6 diff --git a/tests/backend_test.go b/tests/backend_test.go index b3483af..1eacb07 100644 --- a/tests/backend_test.go +++ b/tests/backend_test.go @@ -68,10 +68,6 @@ func Test_NewOrchestrationWorkItem_Single(t *testing.T) { assert.Equal(t, 0, len(state.NewEvents())) assert.Equal(t, 0, len(state.OldEvents())) } - - // Ensure no more work items - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) } } } @@ -113,10 +109,6 @@ func Test_NewOrchestrationWorkItem_Multiple(t *testing.T) { } } } - - // Ensure no more work items - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) } } @@ -166,10 +158,6 @@ func Test_CompleteOrchestration(t *testing.T) { // Execute the test, which calls the above callbacks workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) - - // Ensure no more work items - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) } } } @@ -183,11 +171,6 @@ func Test_ScheduleActivityTasks(t *testing.T) { for i, be := range backends { initTest(t, be, i, true) - wi, err := be.GetActivityWorkItem(ctx) - if !assert.ErrorIs(t, err, backend.ErrNoWorkItems) { - continue - } - // Produce a TaskScheduled event with a particular input getOrchestratorActions := func() []*protos.OrchestratorAction { return []*protos.OrchestratorAction{ @@ -208,21 +191,13 @@ func Test_ScheduleActivityTasks(t *testing.T) { // Execute the test, which calls the above callbacks workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) - // Ensure no more orchestration work items - _, err = be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - // However, there should be an activity work item - wi, err = be.GetActivityWorkItem(ctx) + wi, err := be.NextActivityWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { assert.Equal(t, expectedName, wi.NewEvent.GetTaskScheduled().GetName()) assert.Equal(t, expectedInput, wi.NewEvent.GetTaskScheduled().GetInput().GetValue()) } - // Ensure no more activity work items - _, err = be.GetActivityWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - // Complete the fetched activity work item wi.Result = &protos.HistoryEvent{ EventId: -1, @@ -237,7 +212,7 @@ func Test_ScheduleActivityTasks(t *testing.T) { err = be.CompleteActivityWorkItem(ctx, wi) if assert.NoError(t, err) { // Completing the activity work item should create a new TaskCompleted event - wi, err := be.GetOrchestrationWorkItem(ctx) + wi, err := be.NextOrchestrationWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) && assert.Len(t, wi.NewEvents, 1) { assert.Equal(t, expectedTaskID, wi.NewEvents[0].GetTaskCompleted().GetTaskScheduledId()) assert.Equal(t, expectedResult, wi.NewEvents[0].GetTaskCompleted().GetResult().GetValue()) @@ -270,15 +245,11 @@ func Test_ScheduleTimerTasks(t *testing.T) { // Execute the test, which calls the above callbacks workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) - // Validate that the timer work-item isn't yet visible - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - // Sleep until the expected visibility time expires time.Sleep(timerDuration) // Validate that the timer work-item is now visible - wi, err := be.GetOrchestrationWorkItem(ctx) + wi, err := be.NextOrchestrationWorkItem(ctx) if assert.NoError(t, err) && assert.Equal(t, 1, len(wi.NewEvents)) { e := wi.NewEvents[0] tf := e.GetTimerFired() @@ -330,15 +301,11 @@ func Test_AbandonActivityWorkItem(t *testing.T) { workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) // The NewScheduleTaskAction should have created an activity work item - wi, err := be.GetActivityWorkItem(ctx) + wi, err := be.NextActivityWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { - // Ensure no more activity work items - _, err = be.GetActivityWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - if err := be.AbandonActivityWorkItem(ctx, wi); assert.NoError(t, err) { // Re-fetch the abandoned activity work item - wi, err = be.GetActivityWorkItem(ctx) + wi, err = be.NextActivityWorkItem(ctx) assert.Equal(t, "MyActivity", wi.NewEvent.GetTaskScheduled().GetName()) assert.Equal(t, int32(123), wi.NewEvent.EventId) assert.Nil(t, wi.NewEvent.GetTaskScheduled().GetInput()) @@ -361,9 +328,9 @@ func Test_UninitializedBackend(t *testing.T) { assert.Equal(t, err, backend.ErrNotInitialized) _, err = be.GetOrchestrationRuntimeState(ctx, nil) assert.Equal(t, err, backend.ErrNotInitialized) - _, err = be.GetOrchestrationWorkItem(ctx) + _, err = be.NextOrchestrationWorkItem(ctx) assert.Equal(t, err, backend.ErrNotInitialized) - _, err = be.GetActivityWorkItem(ctx) + _, err = be.NextActivityWorkItem(ctx) assert.Equal(t, err, backend.ErrNotInitialized) } } @@ -511,7 +478,7 @@ func createOrchestrationInstance(t assert.TestingT, be backend.Backend, instance } func getOrchestrationWorkItem(t assert.TestingT, be backend.Backend, expectedInstanceID string) (*backend.OrchestrationWorkItem, bool) { - wi, err := be.GetOrchestrationWorkItem(ctx) + wi, err := be.NextOrchestrationWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { assert.NotEmpty(t, wi.LockedBy) return wi, assert.Equal(t, expectedInstanceID, string(wi.InstanceID)) diff --git a/tests/mocks/Backend.go b/tests/mocks/Backend.go index 78ef150..8e33f40 100644 --- a/tests/mocks/Backend.go +++ b/tests/mocks/Backend.go @@ -416,64 +416,6 @@ func (_c *Backend_DeleteTaskHub_Call) RunAndReturn(run func(context.Context) err return _c } -// GetActivityWorkItem provides a mock function with given fields: _a0 -func (_m *Backend) GetActivityWorkItem(_a0 context.Context) (*backend.ActivityWorkItem, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for GetActivityWorkItem") - } - - var r0 *backend.ActivityWorkItem - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*backend.ActivityWorkItem, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) *backend.ActivityWorkItem); ok { - r0 = rf(_a0) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*backend.ActivityWorkItem) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Backend_GetActivityWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetActivityWorkItem' -type Backend_GetActivityWorkItem_Call struct { - *mock.Call -} - -// GetActivityWorkItem is a helper method to define mock.On call -// - _a0 context.Context -func (_e *Backend_Expecter) GetActivityWorkItem(_a0 interface{}) *Backend_GetActivityWorkItem_Call { - return &Backend_GetActivityWorkItem_Call{Call: _e.mock.On("GetActivityWorkItem", _a0)} -} - -func (_c *Backend_GetActivityWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_GetActivityWorkItem_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *Backend_GetActivityWorkItem_Call) Return(_a0 *backend.ActivityWorkItem, _a1 error) *Backend_GetActivityWorkItem_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Backend_GetActivityWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.ActivityWorkItem, error)) *Backend_GetActivityWorkItem_Call { - _c.Call.Return(run) - return _c -} - // GetOrchestrationMetadata provides a mock function with given fields: _a0, _a1 func (_m *Backend) GetOrchestrationMetadata(_a0 context.Context, _a1 api.InstanceID) (*protos.OrchestrationMetadata, error) { ret := _m.Called(_a0, _a1) @@ -592,12 +534,70 @@ func (_c *Backend_GetOrchestrationRuntimeState_Call) RunAndReturn(run func(conte return _c } -// GetOrchestrationWorkItem provides a mock function with given fields: _a0 -func (_m *Backend) GetOrchestrationWorkItem(_a0 context.Context) (*backend.OrchestrationWorkItem, error) { +// NextActivityWorkItem provides a mock function with given fields: _a0 +func (_m *Backend) NextActivityWorkItem(_a0 context.Context) (*backend.ActivityWorkItem, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for NextActivityWorkItem") + } + + var r0 *backend.ActivityWorkItem + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*backend.ActivityWorkItem, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) *backend.ActivityWorkItem); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*backend.ActivityWorkItem) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Backend_NextActivityWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NextActivityWorkItem' +type Backend_NextActivityWorkItem_Call struct { + *mock.Call +} + +// NextActivityWorkItem is a helper method to define mock.On call +// - _a0 context.Context +func (_e *Backend_Expecter) NextActivityWorkItem(_a0 interface{}) *Backend_NextActivityWorkItem_Call { + return &Backend_NextActivityWorkItem_Call{Call: _e.mock.On("NextActivityWorkItem", _a0)} +} + +func (_c *Backend_NextActivityWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_NextActivityWorkItem_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Backend_NextActivityWorkItem_Call) Return(_a0 *backend.ActivityWorkItem, _a1 error) *Backend_NextActivityWorkItem_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Backend_NextActivityWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.ActivityWorkItem, error)) *Backend_NextActivityWorkItem_Call { + _c.Call.Return(run) + return _c +} + +// NextOrchestrationWorkItem provides a mock function with given fields: _a0 +func (_m *Backend) NextOrchestrationWorkItem(_a0 context.Context) (*backend.OrchestrationWorkItem, error) { ret := _m.Called(_a0) if len(ret) == 0 { - panic("no return value specified for GetOrchestrationWorkItem") + panic("no return value specified for NextOrchestrationWorkItem") } var r0 *backend.OrchestrationWorkItem @@ -622,30 +622,30 @@ func (_m *Backend) GetOrchestrationWorkItem(_a0 context.Context) (*backend.Orche return r0, r1 } -// Backend_GetOrchestrationWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOrchestrationWorkItem' -type Backend_GetOrchestrationWorkItem_Call struct { +// Backend_NextOrchestrationWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NextOrchestrationWorkItem' +type Backend_NextOrchestrationWorkItem_Call struct { *mock.Call } -// GetOrchestrationWorkItem is a helper method to define mock.On call +// NextOrchestrationWorkItem is a helper method to define mock.On call // - _a0 context.Context -func (_e *Backend_Expecter) GetOrchestrationWorkItem(_a0 interface{}) *Backend_GetOrchestrationWorkItem_Call { - return &Backend_GetOrchestrationWorkItem_Call{Call: _e.mock.On("GetOrchestrationWorkItem", _a0)} +func (_e *Backend_Expecter) NextOrchestrationWorkItem(_a0 interface{}) *Backend_NextOrchestrationWorkItem_Call { + return &Backend_NextOrchestrationWorkItem_Call{Call: _e.mock.On("NextOrchestrationWorkItem", _a0)} } -func (_c *Backend_GetOrchestrationWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_GetOrchestrationWorkItem_Call { +func (_c *Backend_NextOrchestrationWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_NextOrchestrationWorkItem_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context)) }) return _c } -func (_c *Backend_GetOrchestrationWorkItem_Call) Return(_a0 *backend.OrchestrationWorkItem, _a1 error) *Backend_GetOrchestrationWorkItem_Call { +func (_c *Backend_NextOrchestrationWorkItem_Call) Return(_a0 *backend.OrchestrationWorkItem, _a1 error) *Backend_NextOrchestrationWorkItem_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Backend_GetOrchestrationWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.OrchestrationWorkItem, error)) *Backend_GetOrchestrationWorkItem_Call { +func (_c *Backend_NextOrchestrationWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.OrchestrationWorkItem, error)) *Backend_NextOrchestrationWorkItem_Call { _c.Call.Return(run) return _c } diff --git a/tests/mocks/TaskWorker.go b/tests/mocks/TaskWorker.go index 4fb7bea..35c04d9 100644 --- a/tests/mocks/TaskWorker.go +++ b/tests/mocks/TaskWorker.go @@ -5,150 +5,96 @@ package mocks import ( context "context" + backend "github.com/dapr/durabletask-go/backend" + mock "github.com/stretchr/testify/mock" ) // TaskWorker is an autogenerated mock type for the TaskWorker type -type TaskWorker struct { +type TaskWorker[T backend.WorkItem] struct { mock.Mock } -type TaskWorker_Expecter struct { +type TaskWorker_Expecter[T backend.WorkItem] struct { mock *mock.Mock } -func (_m *TaskWorker) EXPECT() *TaskWorker_Expecter { - return &TaskWorker_Expecter{mock: &_m.Mock} -} - -// ProcessNext provides a mock function with given fields: _a0 -func (_m *TaskWorker) ProcessNext(_a0 context.Context) (bool, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for ProcessNext") - } - - var r0 bool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (bool, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) bool); ok { - r0 = rf(_a0) - } else { - r0 = ret.Get(0).(bool) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// TaskWorker_ProcessNext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessNext' -type TaskWorker_ProcessNext_Call struct { - *mock.Call -} - -// ProcessNext is a helper method to define mock.On call -// - _a0 context.Context -func (_e *TaskWorker_Expecter) ProcessNext(_a0 interface{}) *TaskWorker_ProcessNext_Call { - return &TaskWorker_ProcessNext_Call{Call: _e.mock.On("ProcessNext", _a0)} -} - -func (_c *TaskWorker_ProcessNext_Call) Run(run func(_a0 context.Context)) *TaskWorker_ProcessNext_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *TaskWorker_ProcessNext_Call) Return(_a0 bool, _a1 error) *TaskWorker_ProcessNext_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *TaskWorker_ProcessNext_Call) RunAndReturn(run func(context.Context) (bool, error)) *TaskWorker_ProcessNext_Call { - _c.Call.Return(run) - return _c +func (_m *TaskWorker[T]) EXPECT() *TaskWorker_Expecter[T] { + return &TaskWorker_Expecter[T]{mock: &_m.Mock} } // Start provides a mock function with given fields: _a0 -func (_m *TaskWorker) Start(_a0 context.Context) { +func (_m *TaskWorker[T]) Start(_a0 context.Context) { _m.Called(_a0) } // TaskWorker_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type TaskWorker_Start_Call struct { +type TaskWorker_Start_Call[T backend.WorkItem] struct { *mock.Call } // Start is a helper method to define mock.On call // - _a0 context.Context -func (_e *TaskWorker_Expecter) Start(_a0 interface{}) *TaskWorker_Start_Call { - return &TaskWorker_Start_Call{Call: _e.mock.On("Start", _a0)} +func (_e *TaskWorker_Expecter[T]) Start(_a0 interface{}) *TaskWorker_Start_Call[T] { + return &TaskWorker_Start_Call[T]{Call: _e.mock.On("Start", _a0)} } -func (_c *TaskWorker_Start_Call) Run(run func(_a0 context.Context)) *TaskWorker_Start_Call { +func (_c *TaskWorker_Start_Call[T]) Run(run func(_a0 context.Context)) *TaskWorker_Start_Call[T] { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context)) }) return _c } -func (_c *TaskWorker_Start_Call) Return() *TaskWorker_Start_Call { +func (_c *TaskWorker_Start_Call[T]) Return() *TaskWorker_Start_Call[T] { _c.Call.Return() return _c } -func (_c *TaskWorker_Start_Call) RunAndReturn(run func(context.Context)) *TaskWorker_Start_Call { +func (_c *TaskWorker_Start_Call[T]) RunAndReturn(run func(context.Context)) *TaskWorker_Start_Call[T] { _c.Call.Return(run) return _c } // StopAndDrain provides a mock function with given fields: -func (_m *TaskWorker) StopAndDrain() { +func (_m *TaskWorker[T]) StopAndDrain() { _m.Called() } // TaskWorker_StopAndDrain_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopAndDrain' -type TaskWorker_StopAndDrain_Call struct { +type TaskWorker_StopAndDrain_Call[T backend.WorkItem] struct { *mock.Call } // StopAndDrain is a helper method to define mock.On call -func (_e *TaskWorker_Expecter) StopAndDrain() *TaskWorker_StopAndDrain_Call { - return &TaskWorker_StopAndDrain_Call{Call: _e.mock.On("StopAndDrain")} +func (_e *TaskWorker_Expecter[T]) StopAndDrain() *TaskWorker_StopAndDrain_Call[T] { + return &TaskWorker_StopAndDrain_Call[T]{Call: _e.mock.On("StopAndDrain")} } -func (_c *TaskWorker_StopAndDrain_Call) Run(run func()) *TaskWorker_StopAndDrain_Call { +func (_c *TaskWorker_StopAndDrain_Call[T]) Run(run func()) *TaskWorker_StopAndDrain_Call[T] { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *TaskWorker_StopAndDrain_Call) Return() *TaskWorker_StopAndDrain_Call { +func (_c *TaskWorker_StopAndDrain_Call[T]) Return() *TaskWorker_StopAndDrain_Call[T] { _c.Call.Return() return _c } -func (_c *TaskWorker_StopAndDrain_Call) RunAndReturn(run func()) *TaskWorker_StopAndDrain_Call { +func (_c *TaskWorker_StopAndDrain_Call[T]) RunAndReturn(run func()) *TaskWorker_StopAndDrain_Call[T] { _c.Call.Return(run) return _c } // NewTaskWorker creates a new instance of TaskWorker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewTaskWorker(t interface { +func NewTaskWorker[T backend.WorkItem](t interface { mock.TestingT Cleanup(func()) -}) *TaskWorker { - mock := &TaskWorker{} +}) *TaskWorker[T] { + mock := &TaskWorker[T]{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/tests/mocks/task.go b/tests/mocks/task.go index ba8a054..f16a969 100644 --- a/tests/mocks/task.go +++ b/tests/mocks/task.go @@ -10,90 +10,87 @@ import ( backend "github.com/dapr/durabletask-go/backend" ) -var _ backend.TaskProcessor = &TestTaskProcessor{} - // TestTaskProcessor implements a dummy task processor useful for testing -type TestTaskProcessor struct { +type TestTaskProcessor[T backend.WorkItem] struct { name string processingBlocked atomic.Bool workItemMu sync.Mutex - workItems []backend.WorkItem + workItems []T abandonedWorkItemMu sync.Mutex - abandonedWorkItems []backend.WorkItem + abandonedWorkItems []T completedWorkItemMu sync.Mutex - completedWorkItems []backend.WorkItem + completedWorkItems []T } -func NewTestTaskPocessor(name string) *TestTaskProcessor { - return &TestTaskProcessor{ +func NewTestTaskPocessor[T backend.WorkItem](name string) *TestTaskProcessor[T] { + return &TestTaskProcessor[T]{ name: name, } } -func (t *TestTaskProcessor) BlockProcessing() { +func (t *TestTaskProcessor[T]) BlockProcessing() { t.processingBlocked.Store(true) } -func (t *TestTaskProcessor) UnblockProcessing() { +func (t *TestTaskProcessor[T]) UnblockProcessing() { t.processingBlocked.Store(false) } -func (t *TestTaskProcessor) PendingWorkItems() []backend.WorkItem { +func (t *TestTaskProcessor[T]) PendingWorkItems() []T { t.workItemMu.Lock() defer t.workItemMu.Unlock() // copy array - return append([]backend.WorkItem{}, t.workItems...) + return append([]T{}, t.workItems...) } -func (t *TestTaskProcessor) AbandonedWorkItems() []backend.WorkItem { +func (t *TestTaskProcessor[T]) AbandonedWorkItems() []T { t.abandonedWorkItemMu.Lock() defer t.abandonedWorkItemMu.Unlock() // copy array - return append([]backend.WorkItem{}, t.abandonedWorkItems...) + return append([]T{}, t.abandonedWorkItems...) } -func (t *TestTaskProcessor) CompletedWorkItems() []backend.WorkItem { +func (t *TestTaskProcessor[T]) CompletedWorkItems() []T { t.completedWorkItemMu.Lock() defer t.completedWorkItemMu.Unlock() // copy array - return append([]backend.WorkItem{}, t.completedWorkItems...) + return append([]T{}, t.completedWorkItems...) } -func (t *TestTaskProcessor) AddWorkItems(wis ...backend.WorkItem) { +func (t *TestTaskProcessor[T]) AddWorkItems(wis ...T) { t.workItemMu.Lock() defer t.workItemMu.Unlock() t.workItems = append(t.workItems, wis...) } -func (t *TestTaskProcessor) Name() string { +func (t *TestTaskProcessor[T]) Name() string { return t.name } -func (t *TestTaskProcessor) FetchWorkItem(context.Context) (backend.WorkItem, error) { +func (t *TestTaskProcessor[T]) NextWorkItem(context.Context) (T, error) { t.workItemMu.Lock() defer t.workItemMu.Unlock() if len(t.workItems) == 0 { - return nil, backend.ErrNoWorkItems + var tt T + return tt, errors.New("no work items") } - // pop first item - i := 0 - wi := t.workItems[i] - t.workItems = append(t.workItems[:i], t.workItems[i+1:]...) + wi := t.workItems[0] + t.workItems = t.workItems[1:] return wi, nil } -func (t *TestTaskProcessor) ProcessWorkItem(ctx context.Context, wi backend.WorkItem) error { +func (t *TestTaskProcessor[T]) ProcessWorkItem(ctx context.Context, wi T) error { if !t.processingBlocked.Load() { return nil } @@ -111,7 +108,7 @@ func (t *TestTaskProcessor) ProcessWorkItem(ctx context.Context, wi backend.Work } } -func (t *TestTaskProcessor) AbandonWorkItem(ctx context.Context, wi backend.WorkItem) error { +func (t *TestTaskProcessor[T]) AbandonWorkItem(ctx context.Context, wi T) error { t.abandonedWorkItemMu.Lock() defer t.abandonedWorkItemMu.Unlock() @@ -119,7 +116,7 @@ func (t *TestTaskProcessor) AbandonWorkItem(ctx context.Context, wi backend.Work return nil } -func (t *TestTaskProcessor) CompleteWorkItem(ctx context.Context, wi backend.WorkItem) error { +func (t *TestTaskProcessor[T]) CompleteWorkItem(ctx context.Context, wi T) error { t.completedWorkItemMu.Lock() defer t.completedWorkItemMu.Unlock() diff --git a/tests/taskhub_test.go b/tests/taskhub_test.go index f9a235c..538004e 100644 --- a/tests/taskhub_test.go +++ b/tests/taskhub_test.go @@ -13,8 +13,8 @@ func Test_TaskHubWorkerStartsDependencies(t *testing.T) { ctx := context.Background() be := mocks.NewBackend(t) - orchWorker := mocks.NewTaskWorker(t) - actWorker := mocks.NewTaskWorker(t) + orchWorker := mocks.NewTaskWorker[*backend.OrchestrationWorkItem](t) + actWorker := mocks.NewTaskWorker[*backend.ActivityWorkItem](t) be.EXPECT().CreateTaskHub(ctx).Return(nil).Once() be.EXPECT().Start(ctx).Return(nil).Once() @@ -30,8 +30,8 @@ func Test_TaskHubWorkerStopsDependencies(t *testing.T) { ctx := context.Background() be := mocks.NewBackend(t) - orchWorker := mocks.NewTaskWorker(t) - actWorker := mocks.NewTaskWorker(t) + orchWorker := mocks.NewTaskWorker[*backend.OrchestrationWorkItem](t) + actWorker := mocks.NewTaskWorker[*backend.ActivityWorkItem](t) be.EXPECT().Stop(ctx).Return(nil).Once() orchWorker.EXPECT().StopAndDrain().Return().Once() diff --git a/tests/worker_test.go b/tests/worker_test.go index 7f67357..098f096 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -2,6 +2,7 @@ package tests import ( "context" + "errors" "sync/atomic" "testing" "time" @@ -46,9 +47,13 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { state := &backend.OrchestrationRuntimeState{} result := &backend.ExecutionResults{Response: &protos.OrchestratorResponse{}} + ctx, cancel := context.WithCancel(ctx) completed := atomic.Bool{} be := mocks.NewBackend(t) - be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { + cancel() + }) be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).RunAndReturn(func(ctx context.Context, owi *backend.OrchestrationWorkItem) error { completed.Store(true) @@ -59,10 +64,7 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { ex.EXPECT().ExecuteOrchestrator(anyContext, wi.InstanceID, state.OldEvents(), mock.Anything).Return(result, nil).Once() worker := backend.NewOrchestrationWorker(be, ex, logger) - ok, err := worker.ProcessNext(ctx) - // Successfully processing a work-item should result in a nil error - assert.Nil(t, err) - assert.True(t, ok) + worker.Start(ctx) require.EventuallyWithT(t, func(collect *assert.CollectT) { if !completed.Load() { @@ -73,17 +75,6 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { worker.StopAndDrain() } -func Test_TryProcessSingleOrchestrationWorkItem_NoWorkItems(t *testing.T) { - ctx := context.Background() - be := mocks.NewBackend(t) - be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(nil, backend.ErrNoWorkItems).Once() - - w := backend.NewOrchestrationWorker(be, nil, logger) - ok, err := w.ProcessNext(ctx) - assert.Nil(t, err) - assert.False(t, ok) -} - func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t *testing.T) { ctx := context.Background() iid := api.InstanceID("test123") @@ -111,8 +102,13 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * // Empty orchestration runtime state since we're starting a new execution from scratch state := backend.NewOrchestrationRuntimeState(iid, []*protos.HistoryEvent{}) + ctx, cancel := context.WithCancel(ctx) be := mocks.NewBackend(t) - be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { + cancel() + }) + be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() ex := mocks.NewExecutor(t) @@ -148,10 +144,11 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * // Set up and run the test worker := backend.NewOrchestrationWorker(be, ex, logger) - ok, err := worker.ProcessNext(ctx) - // Successfully processing a work-item should result in a nil error - assert.Nil(t, err) - assert.True(t, ok) + worker.Start(ctx) + //ok, err := worker.ProcessNext(ctx) + //// Successfully processing a work-item should result in a nil error + //assert.Nil(t, err) + //assert.True(t, ok) require.EventuallyWithT(t, func(collect *assert.CollectT) { if !completed.Load() { @@ -166,18 +163,18 @@ func Test_TaskWorker(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tp := mocks.NewTestTaskPocessor("test") + tp := mocks.NewTestTaskPocessor[*backend.ActivityWorkItem]("test") tp.UnblockProcessing() - first := backend.ActivityWorkItem{ + first := &backend.ActivityWorkItem{ SequenceNumber: 1, } - second := backend.ActivityWorkItem{ + second := &backend.ActivityWorkItem{ SequenceNumber: 2, } tp.AddWorkItems(first, second) - worker := backend.NewTaskWorker(tp, logger) + worker := backend.NewTaskWorker[*backend.ActivityWorkItem](tp, logger) worker.Start(ctx) @@ -213,7 +210,7 @@ func Test_StartAndStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tp := mocks.NewTestTaskPocessor("test") + tp := mocks.NewTestTaskPocessor[*backend.ActivityWorkItem]("test") tp.BlockProcessing() first := backend.ActivityWorkItem{ @@ -222,21 +219,17 @@ func Test_StartAndStop(t *testing.T) { second := backend.ActivityWorkItem{ SequenceNumber: 2, } - tp.AddWorkItems(first, second) + tp.AddWorkItems(&first, &second) - worker := backend.NewTaskWorker(tp, logger) + worker := backend.NewTaskWorker[*backend.ActivityWorkItem](tp, logger) worker.Start(ctx) - require.EventuallyWithT(t, func(collect *assert.CollectT) { - if len(tp.PendingWorkItems()) == 1 { - return - } - collect.Errorf("first work item not consumed yet") - }, 500*time.Millisecond, 100*time.Millisecond) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Len(c, tp.PendingWorkItems(), 1) + }, time.Second*5, 100*time.Millisecond) // due to the configuration of the TestTaskProcessor, now the work item is blocked on ProcessWorkItem until the context is cancelled - drainFinished := make(chan bool) go func() { worker.StopAndDrain()