diff --git a/concurrency/worker.go b/concurrency/worker.go index f40f03348..a0d0bee75 100644 --- a/concurrency/worker.go +++ b/concurrency/worker.go @@ -1,5 +1,18 @@ package concurrency +import ( + "errors" + "sync" +) + +type ClosedAction int + +const ( + PanicWhenClosed ClosedAction = iota + ErrorWhenClosed + SpawnNewGoroutineWhenClosed +) + // NewReusableGoroutinesPool creates a new worker pool with the given size. // These workers will run the workloads passed through Go() calls. // If all workers are busy, Go() will spawn a new goroutine to run the workload. @@ -17,22 +30,54 @@ func NewReusableGoroutinesPool(size int) *ReusableGoroutinesPool { return p } +func (p *ReusableGoroutinesPool) WithClosedAction(action ClosedAction) *ReusableGoroutinesPool { + p.closedAction = action + return p +} + type ReusableGoroutinesPool struct { - jobs chan func() + jobsMu sync.Mutex + closed bool + closedAction ClosedAction + jobs chan func() } // Go will run the given function in a worker of the pool. // If all workers are busy, Go() will spawn a new goroutine to run the workload. -func (p *ReusableGoroutinesPool) Go(f func()) { +func (p *ReusableGoroutinesPool) Go(f func()) error { + p.jobsMu.Lock() + defer p.jobsMu.Unlock() + + if p.closed { + switch p.closedAction { + case PanicWhenClosed: + panic("tried to run a workload on a closed ReusableGoroutinesPool. Use a different ClosedAction to avoid this panic.") + case ErrorWhenClosed: + msg := "tried to run a workload on a closed ReusableGoroutinesPool, dropping the workload" + return errors.New(msg) + case SpawnNewGoroutineWhenClosed: + msg := "tried to run a workload on a closed ReusableGoroutinesPool, spawning a new goroutine to run the workload" + go f() + return errors.New(msg) + } + } + select { case p.jobs <- f: default: go f() } + + return nil } // Close stops the workers of the pool. -// No new Do() calls should be performed after calling Close(). +// No new Go() calls should be performed after calling Close(). // Close does NOT wait for all jobs to finish, it is the caller's responsibility to ensure that in the provided workloads. // Close is intended to be used in tests to ensure that no goroutines are leaked. -func (p *ReusableGoroutinesPool) Close() { close(p.jobs) } +func (p *ReusableGoroutinesPool) Close() { + p.jobsMu.Lock() + defer p.jobsMu.Unlock() + p.closed = true + close(p.jobs) +} diff --git a/concurrency/worker_test.go b/concurrency/worker_test.go index 338062055..7f6883ff1 100644 --- a/concurrency/worker_test.go +++ b/concurrency/worker_test.go @@ -4,10 +4,12 @@ import ( "regexp" "runtime" "strings" + "sync" "testing" "time" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func TestReusableGoroutinesPool(t *testing.T) { @@ -59,3 +61,67 @@ func TestReusableGoroutinesPool(t *testing.T) { } t.Fatalf("expected %d goroutines after closing, got %d", 0, countGoroutines()) } + +func TestReusableGoroutinesPool_ClosedActionPanic(t *testing.T) { + w := NewReusableGoroutinesPool(2) + + runCount, panicked, _ := causePoolFailure(t, w, 10) + + require.NotZero(t, runCount, "expected at least one run") + require.Less(t, runCount, 10, "expected less than 10 runs") + require.True(t, panicked, "expected panic") +} + +func TestReusableGoroutinesPool_ClosedActionError(t *testing.T) { + w := NewReusableGoroutinesPool(2).WithClosedAction(ErrorWhenClosed) + + runCount, panicked, errors := causePoolFailure(t, w, 10) + + require.NotZero(t, runCount, "expected at least one run") + require.Less(t, runCount, 10, "expected less than 10 runs") + require.False(t, panicked, "expected no panic") + require.NotZero(t, len(errors), "expected errors") + require.Less(t, len(errors), 10, "expected less than 10 errors. Some workloads were submitted before close.") +} + +func TestReusableGoroutinesPool_ClosedActionSpawn(t *testing.T) { + w := NewReusableGoroutinesPool(2).WithClosedAction(SpawnNewGoroutineWhenClosed) + + runCount, panicked, errors := causePoolFailure(t, w, 10) + + require.Equal(t, runCount, 10, "expected all workloads to run") + require.False(t, panicked, "expected no panic") + require.NotZero(t, len(errors), "expected errors") + require.Less(t, len(errors), 10, "expected less than 10 errors. Some workloads were submitted before close.") +} + +func causePoolFailure(t *testing.T, w *ReusableGoroutinesPool, maxMsgCount int) (runCount int, panicked bool, errors []error) { + t.Helper() + + var runCountAtomic atomic.Int32 + + var testWG sync.WaitGroup + testWG.Add(1) + go func() { + defer testWG.Done() + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + for i := 0; i < maxMsgCount; i++ { + err := w.Go(func() { + runCountAtomic.Add(1) + }) + if err != nil { + errors = append(errors, err) + } + time.Sleep(10 * time.Millisecond) + } + }() + time.Sleep(10 * time.Millisecond) + w.Close() // close the pool + testWG.Wait() // wait for the test to finish + + return int(runCountAtomic.Load()), panicked, errors +}