From 78eae56c9af1add111daa410d356b40ae01c55d7 Mon Sep 17 00:00:00 2001
From: joshvanl
Date: Tue, 14 Jan 2025 11:41:46 +0000
Subject: [PATCH] Make work queue push based

The work item queue was pull based, meaning that if the backend had no
work to do, the processor would spin doing no work and would also pick
up the next work item more slowly, since a backoff delay was applied
whenever no work was found.

Updates the queue to be push based, calling a function which blocks
until there is either work to do or the context is cancelled.

Signed-off-by: joshvanl
---
 backend/activity.go             |  24 ++--
 backend/backend.go              |  12 +-
 backend/orchestration.go        |  23 ++-
 backend/sqlite/sqlite.go        |  81 ++++++++++-
 backend/taskhub.go              |  11 +-
 backend/worker.go               | 244 +++++++++-----------------------
 backend/workitem.go             |   3 -
 go.mod                          |   1 -
 go.sum                          |   2 -
 submodules/durabletask-protobuf |   2 +-
 tests/backend_test.go           |  49 ++-----
 tests/mocks/Backend.go          | 138 +++++++++---------
 tests/mocks/TaskWorker.go       | 100 +++----------
 tests/mocks/task.go             |  51 ++++---
 tests/taskhub_test.go           |   8 +-
 tests/worker_test.go            |  63 ++++-----
 16 files changed, 332 insertions(+), 480 deletions(-)

diff --git a/backend/activity.go b/backend/activity.go
index 0f2423f..970169c 100644
--- a/backend/activity.go
+++ b/backend/activity.go
@@ -21,12 +21,12 @@ type ActivityExecutor interface {
 	ExecuteActivity(context.Context, api.InstanceID, *protos.HistoryEvent) (*protos.HistoryEvent, error)
 }
 
-func NewActivityTaskWorker(be Backend, executor ActivityExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker {
+func NewActivityTaskWorker(be Backend, executor ActivityExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker[*ActivityWorkItem] {
 	processor := newActivityProcessor(be, executor)
 	return NewTaskWorker(processor, logger, opts...)
} -func newActivityProcessor(be Backend, executor ActivityExecutor) TaskProcessor { +func newActivityProcessor(be Backend, executor ActivityExecutor) TaskProcessor[*ActivityWorkItem] { return &activityProcessor{ be: be, executor: executor, @@ -38,15 +38,13 @@ func (*activityProcessor) Name() string { return "activity-processor" } -// FetchWorkItem implements TaskDispatcher -func (ap *activityProcessor) FetchWorkItem(ctx context.Context) (WorkItem, error) { - return ap.be.GetActivityWorkItem(ctx) +// NextWorkItem implements TaskDispatcher +func (ap *activityProcessor) NextWorkItem(ctx context.Context) (*ActivityWorkItem, error) { + return ap.be.NextActivityWorkItem(ctx) } // ProcessWorkItem implements TaskDispatcher -func (p *activityProcessor) ProcessWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) - +func (p *activityProcessor) ProcessWorkItem(ctx context.Context, awi *ActivityWorkItem) error { ts := awi.NewEvent.GetTaskScheduled() if ts == nil { return fmt.Errorf("%v: invalid TaskScheduled event", awi.InstanceID) @@ -83,20 +81,18 @@ func (p *activityProcessor) ProcessWorkItem(ctx context.Context, wi WorkItem) er } // CompleteWorkItem implements TaskDispatcher -func (ap *activityProcessor) CompleteWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) +func (ap *activityProcessor) CompleteWorkItem(ctx context.Context, awi *ActivityWorkItem) error { if awi.Result == nil { - return fmt.Errorf("can't complete work item '%s' with nil result", wi) + return fmt.Errorf("can't complete work item '%s' with nil result", awi) } if awi.Result.GetTaskCompleted() == nil && awi.Result.GetTaskFailed() == nil { - return fmt.Errorf("can't complete work item '%s', which isn't TaskCompleted or TaskFailed", wi) + return fmt.Errorf("can't complete work item '%s', which isn't TaskCompleted or TaskFailed", awi) } return ap.be.CompleteActivityWorkItem(ctx, awi) } // AbandonWorkItem implements TaskDispatcher -func (ap *activityProcessor) AbandonWorkItem(ctx context.Context, wi WorkItem) error { - awi := wi.(*ActivityWorkItem) +func (ap *activityProcessor) AbandonWorkItem(ctx context.Context, awi *ActivityWorkItem) error { return ap.be.AbandonActivityWorkItem(ctx, awi) } diff --git a/backend/backend.go b/backend/backend.go index 61692a9..cfaebb3 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -68,9 +68,9 @@ type Backend interface { // AddNewEvent adds a new orchestration event to the specified orchestration instance. AddNewOrchestrationEvent(context.Context, api.InstanceID, *HistoryEvent) error - // GetOrchestrationWorkItem gets a pending work item from the task hub or returns [ErrNoOrchWorkItems] - // if there are no pending work items. - GetOrchestrationWorkItem(context.Context) (*OrchestrationWorkItem, error) + // NextOrchestrationWorkItem blocks and returns the next orchestration work + // item from the task hub. Should only return an error when shutting down. + NextOrchestrationWorkItem(context.Context) (*OrchestrationWorkItem, error) // GetOrchestrationRuntimeState gets the runtime state of an orchestration instance. GetOrchestrationRuntimeState(context.Context, *OrchestrationWorkItem) (*OrchestrationRuntimeState, error) @@ -97,9 +97,9 @@ type Backend interface { // completes with a failure is still considered a successfully processed work item). 
AbandonOrchestrationWorkItem(context.Context, *OrchestrationWorkItem) error - // GetActivityWorkItem gets a pending activity work item from the task hub or returns [ErrNoWorkItems] - // if there are no pending activity work items. - GetActivityWorkItem(context.Context) (*ActivityWorkItem, error) + // NextActivityWorkItem blocks and returns the next activity work item from + // the task hub. Should only return an error when shutting down. + NextActivityWorkItem(context.Context) (*ActivityWorkItem, error) // CompleteActivityWorkItem sends a message to the parent orchestration indicating activity completion. // diff --git a/backend/orchestration.go b/backend/orchestration.go index b37653f..a1a3c2d 100644 --- a/backend/orchestration.go +++ b/backend/orchestration.go @@ -30,13 +30,13 @@ type orchestratorProcessor struct { logger Logger } -func NewOrchestrationWorker(be Backend, executor OrchestratorExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { +func NewOrchestrationWorker(be Backend, executor OrchestratorExecutor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker[*OrchestrationWorkItem] { processor := &orchestratorProcessor{ be: be, executor: executor, logger: logger, } - return NewTaskWorker(processor, logger, opts...) + return NewTaskWorker[*OrchestrationWorkItem](processor, logger, opts...) } // Name implements TaskProcessor @@ -44,14 +44,13 @@ func (*orchestratorProcessor) Name() string { return "orchestration-processor" } -// FetchWorkItem implements TaskProcessor -func (p *orchestratorProcessor) FetchWorkItem(ctx context.Context) (WorkItem, error) { - return p.be.GetOrchestrationWorkItem(ctx) +// NextWorkItem implements TaskProcessor +func (p *orchestratorProcessor) NextWorkItem(ctx context.Context) (*OrchestrationWorkItem, error) { + return p.be.NextOrchestrationWorkItem(ctx) } // ProcessWorkItem implements TaskProcessor -func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, cwi WorkItem) error { - wi := cwi.(*OrchestrationWorkItem) +func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { w.logger.Debugf("%v: received work item with %d new event(s): %v", wi.InstanceID, len(wi.NewEvents), helpers.HistoryListSummary(wi.NewEvents)) // TODO: Caching @@ -131,15 +130,13 @@ func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, cwi WorkIte } // CompleteWorkItem implements TaskProcessor -func (p *orchestratorProcessor) CompleteWorkItem(ctx context.Context, wi WorkItem) error { - owi := wi.(*OrchestrationWorkItem) - return p.be.CompleteOrchestrationWorkItem(ctx, owi) +func (p *orchestratorProcessor) CompleteWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { + return p.be.CompleteOrchestrationWorkItem(ctx, wi) } // AbandonWorkItem implements TaskProcessor -func (p *orchestratorProcessor) AbandonWorkItem(ctx context.Context, wi WorkItem) error { - owi := wi.(*OrchestrationWorkItem) - return p.be.AbandonOrchestrationWorkItem(ctx, owi) +func (p *orchestratorProcessor) AbandonWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { + return p.be.AbandonOrchestrationWorkItem(ctx, wi) } func (w *orchestratorProcessor) applyWorkItem(ctx context.Context, wi *OrchestrationWorkItem) (context.Context, trace.Span, bool) { diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 962c5b6..8701650 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -28,6 +28,8 @@ var schema string var emptyString string = "" +var errNoWorkItems = errors.New("no work 
items were found")
+
 type SqliteOptions struct {
 	OrchestrationLockTimeout time.Duration
 	ActivityLockTimeout      time.Duration
@@ -40,6 +42,9 @@ type sqliteBackend struct {
 	workerName string
 	logger     backend.Logger
 	options    *SqliteOptions
+
+	activityWorker      *backend.TaskWorker[*backend.ActivityWorkItem]
+	orchestrationWorker *backend.TaskWorker[*backend.OrchestrationWorkItem]
 }
 
 // NewSqliteOptions creates a new options object for the sqlite backend provider.
@@ -778,8 +783,8 @@ func (be *sqliteBackend) GetOrchestrationRuntimeState(ctx context.Context, wi *b
 	return state, nil
 }
 
-// GetOrchestrationWorkItem implements backend.Backend
-func (be *sqliteBackend) GetOrchestrationWorkItem(ctx context.Context) (*backend.OrchestrationWorkItem, error) {
+// getOrchestrationWorkItem returns a pending orchestration work item, or errNoWorkItems if none is available.
+func (be *sqliteBackend) getOrchestrationWorkItem(ctx context.Context) (*backend.OrchestrationWorkItem, error) {
 	if err := be.ensureDB(); err != nil {
 		return nil, err
 	}
@@ -819,7 +824,7 @@ func (be *sqliteBackend) GetOrchestrationWorkItem(ctx context.Context) (*backend
 	if err := row.Scan(&instanceID); err != nil {
 		if err == sql.ErrNoRows {
 			// No new events to process
-			return nil, backend.ErrNoWorkItems
+			return nil, errNoWorkItems
 		}
 
 		return nil, fmt.Errorf("failed to scan the orchestration work-item: %w", err)
@@ -878,7 +883,73 @@ func (be *sqliteBackend) GetOrchestrationWorkItem(ctx context.Context) (*backend
 	return wi, nil
 }
 
-func (be *sqliteBackend) GetActivityWorkItem(ctx context.Context) (*backend.ActivityWorkItem, error) {
+func (be *sqliteBackend) NextOrchestrationWorkItem(ctx context.Context) (*backend.OrchestrationWorkItem, error) {
+	b := backoff.WithContext(&backoff.ExponentialBackOff{
+		InitialInterval:     50 * time.Millisecond,
+		MaxInterval:         5 * time.Second,
+		Multiplier:          1.05,
+		RandomizationFactor: 0.05,
+		Stop:                backoff.Stop,
+		Clock:               backoff.SystemClock,
+	}, ctx)
+
+	for {
+		wi, err := be.getOrchestrationWorkItem(ctx)
+		if err == nil {
+			return wi, nil
+		}
+
+		if !errors.Is(err, errNoWorkItems) {
+			return nil, err
+		}
+
+		t := time.NewTimer(b.NextBackOff())
+		select {
+		case <-t.C:
+		case <-ctx.Done():
+			if !t.Stop() {
+				<-t.C
+			}
+			be.logger.Info("Orchestration: received cancellation signal")
+			return nil, ctx.Err()
+		}
+	}
+}
+
+func (be *sqliteBackend) NextActivityWorkItem(ctx context.Context) (*backend.ActivityWorkItem, error) {
+	b := backoff.WithContext(&backoff.ExponentialBackOff{
+		InitialInterval:     50 * time.Millisecond,
+		MaxInterval:         5 * time.Second,
+		Multiplier:          1.05,
+		RandomizationFactor: 0.05,
+		Stop:                backoff.Stop,
+		Clock:               backoff.SystemClock,
+	}, ctx)
+
+	for {
+		wi, err := be.getActivityWorkItem(ctx)
+		if err == nil {
+			return wi, nil
+		}
+
+		if !errors.Is(err, errNoWorkItems) {
+			return nil, err
+		}
+
+		t := time.NewTimer(b.NextBackOff())
+		select {
+		case <-t.C:
+		case <-ctx.Done():
+			if !t.Stop() {
+				<-t.C
+			}
+			be.logger.Info("Activity: received cancellation signal")
+			return nil, ctx.Err()
+		}
+	}
+}
+
+func (be *sqliteBackend) getActivityWorkItem(ctx context.Context) (*backend.ActivityWorkItem, error) {
 	if err := be.ensureDB(); err != nil {
 		return nil, err
 	}
@@ -910,7 +981,7 @@ func (be *sqliteBackend) GetActivityWorkItem(ctx context.Context) (*backend.Acti
 	if err := row.Scan(&sequenceNumber, &instanceID, &eventPayload); err != nil {
 		if err == sql.ErrNoRows {
 			// No new activity tasks to process
-			return nil, backend.ErrNoWorkItems
+			return nil, errNoWorkItems
 		}
 
 		return nil, fmt.Errorf("failed to scan the activity work-item: %w", err)
diff --git a/backend/taskhub.go
b/backend/taskhub.go index 688cb73..2fa92cc 100644 --- a/backend/taskhub.go +++ b/backend/taskhub.go @@ -15,12 +15,12 @@ type TaskHubWorker interface { type taskHubWorker struct { backend Backend - orchestrationWorker TaskWorker - activityWorker TaskWorker + orchestrationWorker TaskWorker[*OrchestrationWorkItem] + activityWorker TaskWorker[*ActivityWorkItem] logger Logger } -func NewTaskHubWorker(be Backend, orchestrationWorker TaskWorker, activityWorker TaskWorker, logger Logger) TaskHubWorker { +func NewTaskHubWorker(be Backend, orchestrationWorker TaskWorker[*OrchestrationWorkItem], activityWorker TaskWorker[*ActivityWorkItem], logger Logger) TaskHubWorker { return &taskHubWorker{ backend: be, orchestrationWorker: orchestrationWorker, @@ -30,13 +30,14 @@ func NewTaskHubWorker(be Backend, orchestrationWorker TaskWorker, activityWorker } func (w *taskHubWorker) Start(ctx context.Context) error { - // TODO: Check for already started worker if err := w.backend.CreateTaskHub(ctx); err != nil && err != ErrTaskHubExists { return err } + if err := w.backend.Start(ctx); err != nil { return err } + w.logger.Infof("worker started with backend %v", w.backend) w.orchestrationWorker.Start(ctx) @@ -53,7 +54,7 @@ func (w *taskHubWorker) Shutdown(ctx context.Context) error { w.logger.Info("workers stopping and draining...") defer w.logger.Info("finished stopping and draining workers!") - wg := sync.WaitGroup{} + var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() diff --git a/backend/worker.go b/backend/worker.go index 640e7bc..e4f38a2 100644 --- a/backend/worker.go +++ b/backend/worker.go @@ -2,51 +2,33 @@ package backend import ( "context" - "errors" "sync" - "sync/atomic" - "time" - - "github.com/cenkalti/backoff/v4" - "github.com/marusama/semaphore/v2" ) -type TaskWorker interface { +type TaskWorker[T WorkItem] interface { // Start starts background polling for the activity work items. Start(context.Context) - // ProcessNext attempts to fetch and process a work item. This method returns - // true if a work item was found and processing started; false otherwise. An - // error is returned if the context is cancelled. - ProcessNext(context.Context) (bool, error) - // StopAndDrain stops the worker and waits for all outstanding work items to finish. StopAndDrain() } -type TaskProcessor interface { +type TaskProcessor[T WorkItem] interface { Name() string - FetchWorkItem(context.Context) (WorkItem, error) - ProcessWorkItem(context.Context, WorkItem) error - AbandonWorkItem(context.Context, WorkItem) error - CompleteWorkItem(context.Context, WorkItem) error + ProcessWorkItem(context.Context, T) error + NextWorkItem(context.Context) (T, error) + AbandonWorkItem(context.Context, T) error + CompleteWorkItem(context.Context, T) error } -type worker struct { - options *WorkerOptions - logger Logger - // dispatchSemaphore is for throttling orchestration concurrency. - dispatchSemaphore semaphore.Semaphore - - // pending is for keeping track of outstanding orchestration executions. - pending *sync.WaitGroup - - // cancel is used to cancel background polling. - // It will be nil if background polling isn't started. 
- cancel context.CancelFunc - processor TaskProcessor - waiting bool - stop atomic.Bool +type worker[T WorkItem] struct { + logger Logger + + processor TaskProcessor[T] + closeCh chan struct{} + wg sync.WaitGroup + workItems chan T + parallelLock chan struct{} } type NewTaskWorkerOptions func(*WorkerOptions) @@ -67,188 +49,96 @@ func WithMaxParallelism(n int32) NewTaskWorkerOptions { } } -func NewTaskWorker(p TaskProcessor, logger Logger, opts ...NewTaskWorkerOptions) TaskWorker { +func NewTaskWorker[T WorkItem](p TaskProcessor[T], logger Logger, opts ...NewTaskWorkerOptions) TaskWorker[T] { options := &WorkerOptions{MaxParallelWorkItems: 1} for _, configure := range opts { configure(options) } - return &worker{ - processor: p, - logger: logger, - dispatchSemaphore: semaphore.New(int(options.MaxParallelWorkItems)), - pending: &sync.WaitGroup{}, - cancel: nil, // assigned later - options: options, + return &worker[T]{ + processor: p, + logger: logger, + workItems: make(chan T), + parallelLock: make(chan struct{}, options.MaxParallelWorkItems), + closeCh: make(chan struct{}), } } -func (w *worker) Name() string { +func (w *worker[T]) Name() string { return w.processor.Name() } -func (w *worker) Start(ctx context.Context) { - // TODO: Check for already started worker - ctx, cancel := context.WithCancel(ctx) - w.cancel = cancel +func (w *worker[T]) Start(ctx context.Context) { + w.wg.Add(2) - w.stop.Store(false) + ctx, cancel := context.WithCancel(ctx) go func() { - var b backoff.BackOff = &backoff.ExponentialBackOff{ - InitialInterval: 50 * time.Millisecond, - MaxInterval: 5 * time.Second, - Multiplier: 1.05, - RandomizationFactor: 0.05, - Stop: backoff.Stop, - Clock: backoff.SystemClock, + defer w.wg.Done() + defer cancel() + + select { + case <-w.closeCh: + case <-ctx.Done(): } - b = backoff.WithContext(b, ctx) - b.Reset() + }() + + go func() { + defer w.wg.Done() + defer w.logger.Infof("%v: worker stopped", w.Name()) - loop: for { - // returns right away, with "ok" if a work item was found - ok, err := w.ProcessNext(ctx) - - switch { - case ok: - // found a work item - reset the backoff and check for the next item - b.Reset() - case err != nil && errors.Is(err, ctx.Err()): - // there's an error and it's due to the context being canceled - w.logger.Infof("%v: received cancellation signal", w.Name()) - break loop - case err != nil: - // another error was encountered - // log the error and inject some extra sleep to avoid tight failure loops - w.logger.Errorf("unexpected worker error: %v. 
Adding 5 extra seconds of backoff.", err) - t := time.NewTimer(5 * time.Second) - select { - case <-t.C: - // nop - all good - case <-ctx.Done(): - if !t.Stop() { - <-t.C - } - w.logger.Infof("%v: received cancellation signal", w.Name()) - break loop - } - default: - // no work item found, so sleep until the next backoff - t := time.NewTimer(b.NextBackOff()) - select { - case <-t.C: - // nop - all good - case <-ctx.Done(): - if !t.Stop() { - <-t.C - } - w.logger.Infof("%v: received cancellation signal", w.Name()) - break loop - } + select { + case w.parallelLock <- struct{}{}: + case <-ctx.Done(): + return } - } - w.logger.Infof("%v: stopped listening for new work items", w.Name()) - }() -} + wi, err := w.processor.NextWorkItem(ctx) + if err != nil { + <-w.parallelLock -func (w *worker) ProcessNext(ctx context.Context) (bool, error) { - if !w.dispatchSemaphore.TryAcquire(1) { - w.logger.Debugf("%v: waiting for one of %v in-flight execution(s) to complete", w.Name(), w.dispatchSemaphore.GetCount()) - if err := w.dispatchSemaphore.Acquire(ctx, 1); err != nil { - // cancelled - return false, err - } - } - w.pending.Add(1) + if ctx.Err() != nil { + return + } - processing := false - defer func() { - if !processing { - w.pending.Done() - w.dispatchSemaphore.Release(1) - } - }() + w.logger.Errorf("%v: failed to get next work item: %v", w.Name(), err) + continue + } - wi, err := w.processor.FetchWorkItem(ctx) - switch { - case errors.Is(err, ErrNoWorkItems) || wi == nil: - if !w.waiting { - w.logger.Debugf("%v: waiting for new work items...", w.Name()) - w.waiting = true - } - return false, nil - case err != nil: - if !errors.Is(err, ctx.Err()) { - w.logger.Errorf("%v: failed to fetch work item: %v", w.Name(), err) + w.wg.Add(1) + go func() { + defer func() { + <-w.parallelLock + w.wg.Done() + }() + w.processWorkItem(ctx, wi) + }() } - return false, err - default: - // process the work-item in the background - w.waiting = false - processing = true - go w.processWorkItem(ctx, wi) - return true, nil - } + }() } -func (w *worker) StopAndDrain() { - w.logger.Debugf("%v: stop and drain...", w.Name()) - defer w.logger.Debugf("%v: finished stop and drain...", w.Name()) - - w.stop.Store(true) - - // Cancel the background poller and dispatcher(s) - if w.cancel != nil { - w.cancel() - } - - // Wait for outstanding work-items to finish processing. - // TODO: Need to find a way to cancel this if it takes too long for some reason. 
- w.pending.Wait() +func (w *worker[T]) StopAndDrain() { + close(w.closeCh) + w.wg.Wait() } -func (w *worker) processWorkItem(ctx context.Context, wi WorkItem) { - defer w.dispatchSemaphore.Release(1) - defer w.pending.Done() - +func (w *worker[T]) processWorkItem(ctx context.Context, wi T) { w.logger.Debugf("%v: processing work item: %s", w.Name(), wi) - if w.stop.Load() { - if err := w.processor.AbandonWorkItem(context.Background(), wi); err != nil { - w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) - } - return - } - if err := w.processor.ProcessWorkItem(ctx, wi); err != nil { - if errors.Is(err, ctx.Err()) { - w.logger.Warnf("%v: abandoning work item due to cancellation", w.Name()) - } else { - w.logger.Errorf("%v: failed to process work item: %v", w.Name(), err) - } - if w.stop.Load() { - ctx = context.Background() - } - if err := w.processor.AbandonWorkItem(ctx, wi); err != nil { + w.logger.Errorf("%v: failed to process work item: %v", w.Name(), err) + if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) } return } if err := w.processor.CompleteWorkItem(ctx, wi); err != nil { - if errors.Is(err, ctx.Err()) { - w.logger.Warnf("%v: failed to complete work item due to cancellation", w.Name()) - } else { - w.logger.Errorf("%v: failed to complete work item: %v", w.Name(), err) - } - if w.stop.Load() { - ctx = context.Background() - } - if err := w.processor.AbandonWorkItem(ctx, wi); err != nil { + w.logger.Errorf("%v: failed to complete work item: %v", w.Name(), err) + if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) } + return } w.logger.Debugf("%v: work item processed successfully", w.Name()) diff --git a/backend/workitem.go b/backend/workitem.go index 26ab2d7..1a938e6 100644 --- a/backend/workitem.go +++ b/backend/workitem.go @@ -1,15 +1,12 @@ package backend import ( - "errors" "fmt" "time" "github.com/dapr/durabletask-go/api" ) -var ErrNoWorkItems = errors.New("no work items were found") - type WorkItem interface { fmt.Stringer IsWorkItem() bool diff --git a/go.mod b/go.mod index 3d3b832..4ece2fe 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/cenkalti/backoff/v4 v4.2.1 github.com/dapr/kit v0.13.1-0.20250110192255-fb195706966f github.com/google/uuid v1.6.0 - github.com/marusama/semaphore/v2 v2.5.0 github.com/stretchr/testify v1.9.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.44.0 go.opentelemetry.io/otel v1.18.0 diff --git a/go.sum b/go.sum index b9eb71e..7f05915 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM= -github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= diff --git a/submodules/durabletask-protobuf 
b/submodules/durabletask-protobuf index 4207e1d..c62e0a2 160000 --- a/submodules/durabletask-protobuf +++ b/submodules/durabletask-protobuf @@ -1 +1 @@ -Subproject commit 4207e1dbd14cedc268f69c3befee60fcaad19367 +Subproject commit c62e0a234019a300f334ceedbd96831dbca1bca6 diff --git a/tests/backend_test.go b/tests/backend_test.go index b3483af..1eacb07 100644 --- a/tests/backend_test.go +++ b/tests/backend_test.go @@ -68,10 +68,6 @@ func Test_NewOrchestrationWorkItem_Single(t *testing.T) { assert.Equal(t, 0, len(state.NewEvents())) assert.Equal(t, 0, len(state.OldEvents())) } - - // Ensure no more work items - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) } } } @@ -113,10 +109,6 @@ func Test_NewOrchestrationWorkItem_Multiple(t *testing.T) { } } } - - // Ensure no more work items - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) } } @@ -166,10 +158,6 @@ func Test_CompleteOrchestration(t *testing.T) { // Execute the test, which calls the above callbacks workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) - - // Ensure no more work items - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) } } } @@ -183,11 +171,6 @@ func Test_ScheduleActivityTasks(t *testing.T) { for i, be := range backends { initTest(t, be, i, true) - wi, err := be.GetActivityWorkItem(ctx) - if !assert.ErrorIs(t, err, backend.ErrNoWorkItems) { - continue - } - // Produce a TaskScheduled event with a particular input getOrchestratorActions := func() []*protos.OrchestratorAction { return []*protos.OrchestratorAction{ @@ -208,21 +191,13 @@ func Test_ScheduleActivityTasks(t *testing.T) { // Execute the test, which calls the above callbacks workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) - // Ensure no more orchestration work items - _, err = be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - // However, there should be an activity work item - wi, err = be.GetActivityWorkItem(ctx) + wi, err := be.NextActivityWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { assert.Equal(t, expectedName, wi.NewEvent.GetTaskScheduled().GetName()) assert.Equal(t, expectedInput, wi.NewEvent.GetTaskScheduled().GetInput().GetValue()) } - // Ensure no more activity work items - _, err = be.GetActivityWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - // Complete the fetched activity work item wi.Result = &protos.HistoryEvent{ EventId: -1, @@ -237,7 +212,7 @@ func Test_ScheduleActivityTasks(t *testing.T) { err = be.CompleteActivityWorkItem(ctx, wi) if assert.NoError(t, err) { // Completing the activity work item should create a new TaskCompleted event - wi, err := be.GetOrchestrationWorkItem(ctx) + wi, err := be.NextOrchestrationWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) && assert.Len(t, wi.NewEvents, 1) { assert.Equal(t, expectedTaskID, wi.NewEvents[0].GetTaskCompleted().GetTaskScheduledId()) assert.Equal(t, expectedResult, wi.NewEvents[0].GetTaskCompleted().GetResult().GetValue()) @@ -270,15 +245,11 @@ func Test_ScheduleTimerTasks(t *testing.T) { // Execute the test, which calls the above callbacks workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) - // Validate that the timer work-item isn't yet visible - _, err := be.GetOrchestrationWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - // Sleep until the expected visibility time expires 
time.Sleep(timerDuration) // Validate that the timer work-item is now visible - wi, err := be.GetOrchestrationWorkItem(ctx) + wi, err := be.NextOrchestrationWorkItem(ctx) if assert.NoError(t, err) && assert.Equal(t, 1, len(wi.NewEvents)) { e := wi.NewEvents[0] tf := e.GetTimerFired() @@ -330,15 +301,11 @@ func Test_AbandonActivityWorkItem(t *testing.T) { workItemProcessingTestLogic(t, be, getOrchestratorActions, validateMetadata) // The NewScheduleTaskAction should have created an activity work item - wi, err := be.GetActivityWorkItem(ctx) + wi, err := be.NextActivityWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { - // Ensure no more activity work items - _, err = be.GetActivityWorkItem(ctx) - assert.ErrorIs(t, err, backend.ErrNoWorkItems) - if err := be.AbandonActivityWorkItem(ctx, wi); assert.NoError(t, err) { // Re-fetch the abandoned activity work item - wi, err = be.GetActivityWorkItem(ctx) + wi, err = be.NextActivityWorkItem(ctx) assert.Equal(t, "MyActivity", wi.NewEvent.GetTaskScheduled().GetName()) assert.Equal(t, int32(123), wi.NewEvent.EventId) assert.Nil(t, wi.NewEvent.GetTaskScheduled().GetInput()) @@ -361,9 +328,9 @@ func Test_UninitializedBackend(t *testing.T) { assert.Equal(t, err, backend.ErrNotInitialized) _, err = be.GetOrchestrationRuntimeState(ctx, nil) assert.Equal(t, err, backend.ErrNotInitialized) - _, err = be.GetOrchestrationWorkItem(ctx) + _, err = be.NextOrchestrationWorkItem(ctx) assert.Equal(t, err, backend.ErrNotInitialized) - _, err = be.GetActivityWorkItem(ctx) + _, err = be.NextActivityWorkItem(ctx) assert.Equal(t, err, backend.ErrNotInitialized) } } @@ -511,7 +478,7 @@ func createOrchestrationInstance(t assert.TestingT, be backend.Backend, instance } func getOrchestrationWorkItem(t assert.TestingT, be backend.Backend, expectedInstanceID string) (*backend.OrchestrationWorkItem, bool) { - wi, err := be.GetOrchestrationWorkItem(ctx) + wi, err := be.NextOrchestrationWorkItem(ctx) if assert.NoError(t, err) && assert.NotNil(t, wi) { assert.NotEmpty(t, wi.LockedBy) return wi, assert.Equal(t, expectedInstanceID, string(wi.InstanceID)) diff --git a/tests/mocks/Backend.go b/tests/mocks/Backend.go index 78ef150..8e33f40 100644 --- a/tests/mocks/Backend.go +++ b/tests/mocks/Backend.go @@ -416,64 +416,6 @@ func (_c *Backend_DeleteTaskHub_Call) RunAndReturn(run func(context.Context) err return _c } -// GetActivityWorkItem provides a mock function with given fields: _a0 -func (_m *Backend) GetActivityWorkItem(_a0 context.Context) (*backend.ActivityWorkItem, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for GetActivityWorkItem") - } - - var r0 *backend.ActivityWorkItem - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*backend.ActivityWorkItem, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) *backend.ActivityWorkItem); ok { - r0 = rf(_a0) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*backend.ActivityWorkItem) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Backend_GetActivityWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetActivityWorkItem' -type Backend_GetActivityWorkItem_Call struct { - *mock.Call -} - -// GetActivityWorkItem is a helper method to define mock.On call -// - _a0 context.Context -func (_e *Backend_Expecter) GetActivityWorkItem(_a0 interface{}) *Backend_GetActivityWorkItem_Call { - 
return &Backend_GetActivityWorkItem_Call{Call: _e.mock.On("GetActivityWorkItem", _a0)} -} - -func (_c *Backend_GetActivityWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_GetActivityWorkItem_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *Backend_GetActivityWorkItem_Call) Return(_a0 *backend.ActivityWorkItem, _a1 error) *Backend_GetActivityWorkItem_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *Backend_GetActivityWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.ActivityWorkItem, error)) *Backend_GetActivityWorkItem_Call { - _c.Call.Return(run) - return _c -} - // GetOrchestrationMetadata provides a mock function with given fields: _a0, _a1 func (_m *Backend) GetOrchestrationMetadata(_a0 context.Context, _a1 api.InstanceID) (*protos.OrchestrationMetadata, error) { ret := _m.Called(_a0, _a1) @@ -592,12 +534,70 @@ func (_c *Backend_GetOrchestrationRuntimeState_Call) RunAndReturn(run func(conte return _c } -// GetOrchestrationWorkItem provides a mock function with given fields: _a0 -func (_m *Backend) GetOrchestrationWorkItem(_a0 context.Context) (*backend.OrchestrationWorkItem, error) { +// NextActivityWorkItem provides a mock function with given fields: _a0 +func (_m *Backend) NextActivityWorkItem(_a0 context.Context) (*backend.ActivityWorkItem, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for NextActivityWorkItem") + } + + var r0 *backend.ActivityWorkItem + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*backend.ActivityWorkItem, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) *backend.ActivityWorkItem); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*backend.ActivityWorkItem) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Backend_NextActivityWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NextActivityWorkItem' +type Backend_NextActivityWorkItem_Call struct { + *mock.Call +} + +// NextActivityWorkItem is a helper method to define mock.On call +// - _a0 context.Context +func (_e *Backend_Expecter) NextActivityWorkItem(_a0 interface{}) *Backend_NextActivityWorkItem_Call { + return &Backend_NextActivityWorkItem_Call{Call: _e.mock.On("NextActivityWorkItem", _a0)} +} + +func (_c *Backend_NextActivityWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_NextActivityWorkItem_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Backend_NextActivityWorkItem_Call) Return(_a0 *backend.ActivityWorkItem, _a1 error) *Backend_NextActivityWorkItem_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Backend_NextActivityWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.ActivityWorkItem, error)) *Backend_NextActivityWorkItem_Call { + _c.Call.Return(run) + return _c +} + +// NextOrchestrationWorkItem provides a mock function with given fields: _a0 +func (_m *Backend) NextOrchestrationWorkItem(_a0 context.Context) (*backend.OrchestrationWorkItem, error) { ret := _m.Called(_a0) if len(ret) == 0 { - panic("no return value specified for GetOrchestrationWorkItem") + panic("no return value specified for NextOrchestrationWorkItem") } var r0 *backend.OrchestrationWorkItem @@ -622,30 +622,30 @@ func (_m *Backend) GetOrchestrationWorkItem(_a0 
context.Context) (*backend.Orche return r0, r1 } -// Backend_GetOrchestrationWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOrchestrationWorkItem' -type Backend_GetOrchestrationWorkItem_Call struct { +// Backend_NextOrchestrationWorkItem_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NextOrchestrationWorkItem' +type Backend_NextOrchestrationWorkItem_Call struct { *mock.Call } -// GetOrchestrationWorkItem is a helper method to define mock.On call +// NextOrchestrationWorkItem is a helper method to define mock.On call // - _a0 context.Context -func (_e *Backend_Expecter) GetOrchestrationWorkItem(_a0 interface{}) *Backend_GetOrchestrationWorkItem_Call { - return &Backend_GetOrchestrationWorkItem_Call{Call: _e.mock.On("GetOrchestrationWorkItem", _a0)} +func (_e *Backend_Expecter) NextOrchestrationWorkItem(_a0 interface{}) *Backend_NextOrchestrationWorkItem_Call { + return &Backend_NextOrchestrationWorkItem_Call{Call: _e.mock.On("NextOrchestrationWorkItem", _a0)} } -func (_c *Backend_GetOrchestrationWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_GetOrchestrationWorkItem_Call { +func (_c *Backend_NextOrchestrationWorkItem_Call) Run(run func(_a0 context.Context)) *Backend_NextOrchestrationWorkItem_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context)) }) return _c } -func (_c *Backend_GetOrchestrationWorkItem_Call) Return(_a0 *backend.OrchestrationWorkItem, _a1 error) *Backend_GetOrchestrationWorkItem_Call { +func (_c *Backend_NextOrchestrationWorkItem_Call) Return(_a0 *backend.OrchestrationWorkItem, _a1 error) *Backend_NextOrchestrationWorkItem_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Backend_GetOrchestrationWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.OrchestrationWorkItem, error)) *Backend_GetOrchestrationWorkItem_Call { +func (_c *Backend_NextOrchestrationWorkItem_Call) RunAndReturn(run func(context.Context) (*backend.OrchestrationWorkItem, error)) *Backend_NextOrchestrationWorkItem_Call { _c.Call.Return(run) return _c } diff --git a/tests/mocks/TaskWorker.go b/tests/mocks/TaskWorker.go index 4fb7bea..35c04d9 100644 --- a/tests/mocks/TaskWorker.go +++ b/tests/mocks/TaskWorker.go @@ -5,150 +5,96 @@ package mocks import ( context "context" + backend "github.com/dapr/durabletask-go/backend" + mock "github.com/stretchr/testify/mock" ) // TaskWorker is an autogenerated mock type for the TaskWorker type -type TaskWorker struct { +type TaskWorker[T backend.WorkItem] struct { mock.Mock } -type TaskWorker_Expecter struct { +type TaskWorker_Expecter[T backend.WorkItem] struct { mock *mock.Mock } -func (_m *TaskWorker) EXPECT() *TaskWorker_Expecter { - return &TaskWorker_Expecter{mock: &_m.Mock} -} - -// ProcessNext provides a mock function with given fields: _a0 -func (_m *TaskWorker) ProcessNext(_a0 context.Context) (bool, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for ProcessNext") - } - - var r0 bool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (bool, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) bool); ok { - r0 = rf(_a0) - } else { - r0 = ret.Get(0).(bool) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// TaskWorker_ProcessNext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ProcessNext' 
-type TaskWorker_ProcessNext_Call struct { - *mock.Call -} - -// ProcessNext is a helper method to define mock.On call -// - _a0 context.Context -func (_e *TaskWorker_Expecter) ProcessNext(_a0 interface{}) *TaskWorker_ProcessNext_Call { - return &TaskWorker_ProcessNext_Call{Call: _e.mock.On("ProcessNext", _a0)} -} - -func (_c *TaskWorker_ProcessNext_Call) Run(run func(_a0 context.Context)) *TaskWorker_ProcessNext_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *TaskWorker_ProcessNext_Call) Return(_a0 bool, _a1 error) *TaskWorker_ProcessNext_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *TaskWorker_ProcessNext_Call) RunAndReturn(run func(context.Context) (bool, error)) *TaskWorker_ProcessNext_Call { - _c.Call.Return(run) - return _c +func (_m *TaskWorker[T]) EXPECT() *TaskWorker_Expecter[T] { + return &TaskWorker_Expecter[T]{mock: &_m.Mock} } // Start provides a mock function with given fields: _a0 -func (_m *TaskWorker) Start(_a0 context.Context) { +func (_m *TaskWorker[T]) Start(_a0 context.Context) { _m.Called(_a0) } // TaskWorker_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type TaskWorker_Start_Call struct { +type TaskWorker_Start_Call[T backend.WorkItem] struct { *mock.Call } // Start is a helper method to define mock.On call // - _a0 context.Context -func (_e *TaskWorker_Expecter) Start(_a0 interface{}) *TaskWorker_Start_Call { - return &TaskWorker_Start_Call{Call: _e.mock.On("Start", _a0)} +func (_e *TaskWorker_Expecter[T]) Start(_a0 interface{}) *TaskWorker_Start_Call[T] { + return &TaskWorker_Start_Call[T]{Call: _e.mock.On("Start", _a0)} } -func (_c *TaskWorker_Start_Call) Run(run func(_a0 context.Context)) *TaskWorker_Start_Call { +func (_c *TaskWorker_Start_Call[T]) Run(run func(_a0 context.Context)) *TaskWorker_Start_Call[T] { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context)) }) return _c } -func (_c *TaskWorker_Start_Call) Return() *TaskWorker_Start_Call { +func (_c *TaskWorker_Start_Call[T]) Return() *TaskWorker_Start_Call[T] { _c.Call.Return() return _c } -func (_c *TaskWorker_Start_Call) RunAndReturn(run func(context.Context)) *TaskWorker_Start_Call { +func (_c *TaskWorker_Start_Call[T]) RunAndReturn(run func(context.Context)) *TaskWorker_Start_Call[T] { _c.Call.Return(run) return _c } // StopAndDrain provides a mock function with given fields: -func (_m *TaskWorker) StopAndDrain() { +func (_m *TaskWorker[T]) StopAndDrain() { _m.Called() } // TaskWorker_StopAndDrain_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopAndDrain' -type TaskWorker_StopAndDrain_Call struct { +type TaskWorker_StopAndDrain_Call[T backend.WorkItem] struct { *mock.Call } // StopAndDrain is a helper method to define mock.On call -func (_e *TaskWorker_Expecter) StopAndDrain() *TaskWorker_StopAndDrain_Call { - return &TaskWorker_StopAndDrain_Call{Call: _e.mock.On("StopAndDrain")} +func (_e *TaskWorker_Expecter[T]) StopAndDrain() *TaskWorker_StopAndDrain_Call[T] { + return &TaskWorker_StopAndDrain_Call[T]{Call: _e.mock.On("StopAndDrain")} } -func (_c *TaskWorker_StopAndDrain_Call) Run(run func()) *TaskWorker_StopAndDrain_Call { +func (_c *TaskWorker_StopAndDrain_Call[T]) Run(run func()) *TaskWorker_StopAndDrain_Call[T] { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *TaskWorker_StopAndDrain_Call) Return() *TaskWorker_StopAndDrain_Call { +func (_c 
*TaskWorker_StopAndDrain_Call[T]) Return() *TaskWorker_StopAndDrain_Call[T] { _c.Call.Return() return _c } -func (_c *TaskWorker_StopAndDrain_Call) RunAndReturn(run func()) *TaskWorker_StopAndDrain_Call { +func (_c *TaskWorker_StopAndDrain_Call[T]) RunAndReturn(run func()) *TaskWorker_StopAndDrain_Call[T] { _c.Call.Return(run) return _c } // NewTaskWorker creates a new instance of TaskWorker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewTaskWorker(t interface { +func NewTaskWorker[T backend.WorkItem](t interface { mock.TestingT Cleanup(func()) -}) *TaskWorker { - mock := &TaskWorker{} +}) *TaskWorker[T] { + mock := &TaskWorker[T]{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/tests/mocks/task.go b/tests/mocks/task.go index ba8a054..f16a969 100644 --- a/tests/mocks/task.go +++ b/tests/mocks/task.go @@ -10,90 +10,87 @@ import ( backend "github.com/dapr/durabletask-go/backend" ) -var _ backend.TaskProcessor = &TestTaskProcessor{} - // TestTaskProcessor implements a dummy task processor useful for testing -type TestTaskProcessor struct { +type TestTaskProcessor[T backend.WorkItem] struct { name string processingBlocked atomic.Bool workItemMu sync.Mutex - workItems []backend.WorkItem + workItems []T abandonedWorkItemMu sync.Mutex - abandonedWorkItems []backend.WorkItem + abandonedWorkItems []T completedWorkItemMu sync.Mutex - completedWorkItems []backend.WorkItem + completedWorkItems []T } -func NewTestTaskPocessor(name string) *TestTaskProcessor { - return &TestTaskProcessor{ +func NewTestTaskPocessor[T backend.WorkItem](name string) *TestTaskProcessor[T] { + return &TestTaskProcessor[T]{ name: name, } } -func (t *TestTaskProcessor) BlockProcessing() { +func (t *TestTaskProcessor[T]) BlockProcessing() { t.processingBlocked.Store(true) } -func (t *TestTaskProcessor) UnblockProcessing() { +func (t *TestTaskProcessor[T]) UnblockProcessing() { t.processingBlocked.Store(false) } -func (t *TestTaskProcessor) PendingWorkItems() []backend.WorkItem { +func (t *TestTaskProcessor[T]) PendingWorkItems() []T { t.workItemMu.Lock() defer t.workItemMu.Unlock() // copy array - return append([]backend.WorkItem{}, t.workItems...) + return append([]T{}, t.workItems...) } -func (t *TestTaskProcessor) AbandonedWorkItems() []backend.WorkItem { +func (t *TestTaskProcessor[T]) AbandonedWorkItems() []T { t.abandonedWorkItemMu.Lock() defer t.abandonedWorkItemMu.Unlock() // copy array - return append([]backend.WorkItem{}, t.abandonedWorkItems...) + return append([]T{}, t.abandonedWorkItems...) } -func (t *TestTaskProcessor) CompletedWorkItems() []backend.WorkItem { +func (t *TestTaskProcessor[T]) CompletedWorkItems() []T { t.completedWorkItemMu.Lock() defer t.completedWorkItemMu.Unlock() // copy array - return append([]backend.WorkItem{}, t.completedWorkItems...) + return append([]T{}, t.completedWorkItems...) } -func (t *TestTaskProcessor) AddWorkItems(wis ...backend.WorkItem) { +func (t *TestTaskProcessor[T]) AddWorkItems(wis ...T) { t.workItemMu.Lock() defer t.workItemMu.Unlock() t.workItems = append(t.workItems, wis...) 
} -func (t *TestTaskProcessor) Name() string { +func (t *TestTaskProcessor[T]) Name() string { return t.name } -func (t *TestTaskProcessor) FetchWorkItem(context.Context) (backend.WorkItem, error) { +func (t *TestTaskProcessor[T]) NextWorkItem(context.Context) (T, error) { t.workItemMu.Lock() defer t.workItemMu.Unlock() if len(t.workItems) == 0 { - return nil, backend.ErrNoWorkItems + var tt T + return tt, errors.New("no work items") } - // pop first item - i := 0 - wi := t.workItems[i] - t.workItems = append(t.workItems[:i], t.workItems[i+1:]...) + wi := t.workItems[0] + t.workItems = t.workItems[1:] return wi, nil } -func (t *TestTaskProcessor) ProcessWorkItem(ctx context.Context, wi backend.WorkItem) error { +func (t *TestTaskProcessor[T]) ProcessWorkItem(ctx context.Context, wi T) error { if !t.processingBlocked.Load() { return nil } @@ -111,7 +108,7 @@ func (t *TestTaskProcessor) ProcessWorkItem(ctx context.Context, wi backend.Work } } -func (t *TestTaskProcessor) AbandonWorkItem(ctx context.Context, wi backend.WorkItem) error { +func (t *TestTaskProcessor[T]) AbandonWorkItem(ctx context.Context, wi T) error { t.abandonedWorkItemMu.Lock() defer t.abandonedWorkItemMu.Unlock() @@ -119,7 +116,7 @@ func (t *TestTaskProcessor) AbandonWorkItem(ctx context.Context, wi backend.Work return nil } -func (t *TestTaskProcessor) CompleteWorkItem(ctx context.Context, wi backend.WorkItem) error { +func (t *TestTaskProcessor[T]) CompleteWorkItem(ctx context.Context, wi T) error { t.completedWorkItemMu.Lock() defer t.completedWorkItemMu.Unlock() diff --git a/tests/taskhub_test.go b/tests/taskhub_test.go index f9a235c..538004e 100644 --- a/tests/taskhub_test.go +++ b/tests/taskhub_test.go @@ -13,8 +13,8 @@ func Test_TaskHubWorkerStartsDependencies(t *testing.T) { ctx := context.Background() be := mocks.NewBackend(t) - orchWorker := mocks.NewTaskWorker(t) - actWorker := mocks.NewTaskWorker(t) + orchWorker := mocks.NewTaskWorker[*backend.OrchestrationWorkItem](t) + actWorker := mocks.NewTaskWorker[*backend.ActivityWorkItem](t) be.EXPECT().CreateTaskHub(ctx).Return(nil).Once() be.EXPECT().Start(ctx).Return(nil).Once() @@ -30,8 +30,8 @@ func Test_TaskHubWorkerStopsDependencies(t *testing.T) { ctx := context.Background() be := mocks.NewBackend(t) - orchWorker := mocks.NewTaskWorker(t) - actWorker := mocks.NewTaskWorker(t) + orchWorker := mocks.NewTaskWorker[*backend.OrchestrationWorkItem](t) + actWorker := mocks.NewTaskWorker[*backend.ActivityWorkItem](t) be.EXPECT().Stop(ctx).Return(nil).Once() orchWorker.EXPECT().StopAndDrain().Return().Once() diff --git a/tests/worker_test.go b/tests/worker_test.go index 7f67357..098f096 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -2,6 +2,7 @@ package tests import ( "context" + "errors" "sync/atomic" "testing" "time" @@ -46,9 +47,13 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { state := &backend.OrchestrationRuntimeState{} result := &backend.ExecutionResults{Response: &protos.OrchestratorResponse{}} + ctx, cancel := context.WithCancel(ctx) completed := atomic.Bool{} be := mocks.NewBackend(t) - be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { + cancel() + }) be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() be.EXPECT().CompleteOrchestrationWorkItem(anyContext, 
wi).RunAndReturn(func(ctx context.Context, owi *backend.OrchestrationWorkItem) error { completed.Store(true) @@ -59,10 +64,7 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { ex.EXPECT().ExecuteOrchestrator(anyContext, wi.InstanceID, state.OldEvents(), mock.Anything).Return(result, nil).Once() worker := backend.NewOrchestrationWorker(be, ex, logger) - ok, err := worker.ProcessNext(ctx) - // Successfully processing a work-item should result in a nil error - assert.Nil(t, err) - assert.True(t, ok) + worker.Start(ctx) require.EventuallyWithT(t, func(collect *assert.CollectT) { if !completed.Load() { @@ -73,17 +75,6 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { worker.StopAndDrain() } -func Test_TryProcessSingleOrchestrationWorkItem_NoWorkItems(t *testing.T) { - ctx := context.Background() - be := mocks.NewBackend(t) - be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(nil, backend.ErrNoWorkItems).Once() - - w := backend.NewOrchestrationWorker(be, nil, logger) - ok, err := w.ProcessNext(ctx) - assert.Nil(t, err) - assert.False(t, ok) -} - func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t *testing.T) { ctx := context.Background() iid := api.InstanceID("test123") @@ -111,8 +102,13 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * // Empty orchestration runtime state since we're starting a new execution from scratch state := backend.NewOrchestrationRuntimeState(iid, []*protos.HistoryEvent{}) + ctx, cancel := context.WithCancel(ctx) be := mocks.NewBackend(t) - be.EXPECT().GetOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() + be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { + cancel() + }) + be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() ex := mocks.NewExecutor(t) @@ -148,10 +144,11 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * // Set up and run the test worker := backend.NewOrchestrationWorker(be, ex, logger) - ok, err := worker.ProcessNext(ctx) - // Successfully processing a work-item should result in a nil error - assert.Nil(t, err) - assert.True(t, ok) + worker.Start(ctx) + //ok, err := worker.ProcessNext(ctx) + //// Successfully processing a work-item should result in a nil error + //assert.Nil(t, err) + //assert.True(t, ok) require.EventuallyWithT(t, func(collect *assert.CollectT) { if !completed.Load() { @@ -166,18 +163,18 @@ func Test_TaskWorker(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tp := mocks.NewTestTaskPocessor("test") + tp := mocks.NewTestTaskPocessor[*backend.ActivityWorkItem]("test") tp.UnblockProcessing() - first := backend.ActivityWorkItem{ + first := &backend.ActivityWorkItem{ SequenceNumber: 1, } - second := backend.ActivityWorkItem{ + second := &backend.ActivityWorkItem{ SequenceNumber: 2, } tp.AddWorkItems(first, second) - worker := backend.NewTaskWorker(tp, logger) + worker := backend.NewTaskWorker[*backend.ActivityWorkItem](tp, logger) worker.Start(ctx) @@ -213,7 +210,7 @@ func Test_StartAndStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tp := mocks.NewTestTaskPocessor("test") + tp := mocks.NewTestTaskPocessor[*backend.ActivityWorkItem]("test") tp.BlockProcessing() first := backend.ActivityWorkItem{ @@ -222,21 +219,17 @@ func 
Test_StartAndStop(t *testing.T) { second := backend.ActivityWorkItem{ SequenceNumber: 2, } - tp.AddWorkItems(first, second) + tp.AddWorkItems(&first, &second) - worker := backend.NewTaskWorker(tp, logger) + worker := backend.NewTaskWorker[*backend.ActivityWorkItem](tp, logger) worker.Start(ctx) - require.EventuallyWithT(t, func(collect *assert.CollectT) { - if len(tp.PendingWorkItems()) == 1 { - return - } - collect.Errorf("first work item not consumed yet") - }, 500*time.Millisecond, 100*time.Millisecond) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Len(c, tp.PendingWorkItems(), 1) + }, time.Second*5, 100*time.Millisecond) // due to the configuration of the TestTaskProcessor, now the work item is blocked on ProcessWorkItem until the context is cancelled - drainFinished := make(chan bool) go func() { worker.StopAndDrain()