Skip to content

Commit c7113e1

Browse files
authored
🔀 Merge pull request #112 from perun-network/110-concurrentt-waitctx
✨ [pkg/test] Add context to ConcurrentT.
2 parents f5f7b3e + 9cb3942 commit c7113e1

File tree

3 files changed

+96
-39
lines changed

3 files changed

+96
-39
lines changed

pkg/test/concurrent.go

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package test
1616

1717
import (
18+
"context"
1819
"runtime"
1920
"strconv"
2021
"strings"
@@ -130,17 +131,26 @@ type ConcurrentT struct {
130131
t require.TestingT
131132
failed bool
132133
failedCh chan struct{}
134+
ctx context.Context
133135

134136
mutex sync.Mutex
135137
stages map[string]*stage
136138
}
137139

138140
// NewConcurrent creates a new concurrent testing object.
139141
func NewConcurrent(t require.TestingT) *ConcurrentT {
142+
return NewConcurrentCtx(t, context.Background())
143+
}
144+
145+
// NewConcurrentCtx creates a new concurrent testing object controlled by a
146+
// context. If that context expires, any ongoing stages and wait calls will
147+
// fail.
148+
func NewConcurrentCtx(t require.TestingT, ctx context.Context) *ConcurrentT {
140149
return &ConcurrentT{
141150
t: t,
142151
stages: make(map[string]*stage),
143152
failedCh: make(chan struct{}),
153+
ctx: ctx,
144154
}
145155
}
146156

@@ -167,8 +177,10 @@ func (t *ConcurrentT) getStage(name string) *stage {
167177
return s
168178
}
169179

170-
// Wait waits until the stages and barriers with the requested names terminate.
171-
// If any stage or barrier fails, terminates the current goroutine or test.
180+
// Wait waits until the stages and barriers with the requested names
181+
// terminate or the test's context expires. If the context expires, fails the
182+
// test. If any stage or barrier fails, terminates the current goroutine or
183+
// test.
172184
func (t *ConcurrentT) Wait(names ...string) {
173185
if len(names) == 0 {
174186
panic("Wait(): called with 0 names")
@@ -177,6 +189,11 @@ func (t *ConcurrentT) Wait(names ...string) {
177189
for _, name := range names {
178190
stage := t.getStage(name)
179191
select {
192+
case <-t.ctx.Done():
193+
t.failNowMutex.Lock()
194+
t.t.Errorf("Wait for stage %s: %v", name, t.ctx.Err())
195+
t.failNowMutex.Unlock()
196+
t.FailNow()
180197
case <-stage.wg.WaitCh():
181198
if stage.failed.IsSet() {
182199
t.FailNow()
@@ -209,28 +226,41 @@ func (t *ConcurrentT) FailNow() {
209226
// fn must not spawn any goroutines or pass along the T object to goroutines
210227
// that call T.Fatal. To achieve this, make other goroutines call
211228
// ConcurrentT.StageN() instead.
229+
// If the test's context expires before the call returns, fails the test.
212230
func (t *ConcurrentT) StageN(name string, goroutines int, fn func(ConcT)) {
213231
stage := t.spawnStage(name, goroutines)
214232

215233
stageT := ConcT{TestingT: stage, ct: t}
216-
abort := CheckAbort(func() {
234+
abort, ok := CheckAbortCtx(t.ctx, func() {
217235
fn(stageT)
218236
})
219237

220-
if abort != nil {
221-
// Fail the stage, if it had not been marked as such, yet.
222-
if stage.failed.TrySet() {
223-
defer stage.wg.Done()
224-
}
225-
// If it is a panic or Goexit from certain contexts, print stack trace.
226-
if _, ok := abort.(*Panic); ok || shouldPrintStack(abort.Stack()) {
227-
print("\n", abort.String())
228-
}
238+
if ok && abort == nil {
239+
stage.pass()
240+
t.Wait(name)
241+
return
242+
}
243+
244+
// Fail the stage, if it had not been marked as such, yet.
245+
if stage.failed.TrySet() {
246+
defer stage.wg.Done()
247+
}
248+
249+
// If it did not terminate, just abort the test.
250+
if !ok {
251+
t.failNowMutex.Lock()
252+
t.t.Errorf("Stage %s: %v", name, t.ctx.Err())
253+
t.failNowMutex.Unlock()
229254
t.FailNow()
230255
}
231256

232-
stage.pass()
233-
t.Wait(name)
257+
// If it is a panic or Goexit from certain contexts, print stack trace.
258+
if _, ok := abort.(*Panic); ok || shouldPrintStack(abort.Stack()) {
259+
t.failNowMutex.Lock()
260+
t.t.Errorf("Stage %s: %s", name, abort.String())
261+
t.failNowMutex.Unlock()
262+
}
263+
t.FailNow()
234264
}
235265

236266
func shouldPrintStack(stack string) bool {

pkg/test/concurrent_external_test.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package test_test
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"strconv"
2021
"sync"
@@ -50,20 +51,34 @@ func TestConcurrentT_Wait(t *testing.T) {
5051
})
5152
ctxtest.AssertTerminates(t, timeout, func() { ct.Wait("known") })
5253
})
54+
55+
t.Run("context expiry", func(t *testing.T) {
56+
ctxtest.AssertTerminates(t, timeout, func() {
57+
test.AssertFatal(t, func(t test.T) {
58+
ctx, cancel := context.WithCancel(context.Background())
59+
cancel()
60+
test.NewConcurrentCtx(t, ctx).Stage("", func(test.ConcT) {
61+
time.Sleep(timeout)
62+
})
63+
})
64+
})
65+
})
5366
}
5467

5568
func TestConcurrentT_FailNow(t *testing.T) {
56-
var ct *test.ConcurrentT
69+
t.Run("idempotence", func(t *testing.T) {
70+
var ct *test.ConcurrentT
5771

58-
// Test that NewConcurrent.FailNow() calls T.FailNow().
59-
test.AssertFatal(t, func(t test.T) {
60-
ct = test.NewConcurrent(t)
61-
ct.FailNow()
62-
})
72+
// Test that NewConcurrent.FailNow() calls T.FailNow().
73+
test.AssertFatal(t, func(t test.T) {
74+
ct = test.NewConcurrent(t)
75+
ct.FailNow()
76+
})
6377

64-
// Test that after that, FailNow() calls runtime.Goexit().
65-
assert.True(t, test.CheckGoexit(ct.FailNow),
66-
"redundant FailNow() must call runtime.Goexit()")
78+
// Test that after that, FailNow() calls runtime.Goexit().
79+
assert.True(t, test.CheckGoexit(ct.FailNow),
80+
"redundant FailNow() must call runtime.Goexit()")
81+
})
6782

6883
t.Run("hammer", func(t *testing.T) {
6984
const parallel = 12

pkg/test/goexit.go

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package test
1616

1717
import (
18+
"context"
1819
"fmt"
1920
"runtime/debug"
2021
"strings"
@@ -64,10 +65,11 @@ func (g Goexit) String() string {
6465
return "runtime.Goexit:\n\n" + g.Stack()
6566
}
6667

67-
// CheckAbort tests whether a supplied function is aborted early using panic()
68-
// or runtime.Goexit(). Returns a descriptor of the termination cause or nil if
69-
// it terminated normally.
70-
func CheckAbort(function func()) (abort Abort) {
68+
// CheckAbortCtx tests whether a supplied function terminates within a context,
69+
// and whether it is aborted early using panic() or runtime.Goexit(). Returns
70+
// whether the function terminated before the expiry of the context and if so, a
71+
// descriptor of the termination cause or nil if it terminated normally.
72+
func CheckAbortCtx(ctx context.Context, function func()) (abort Abort, ok bool) {
7173
done := make(chan struct{})
7274

7375
goexit := true // Whether runtime.Goexit occurred.
@@ -103,20 +105,30 @@ func CheckAbort(function func()) (abort Abort) {
103105
goexit = false
104106
}()
105107

106-
<-done
107-
108-
// Concatenate the inner call stack of the failure (which starts at the
109-
// goroutine instantiation) with the goroutine that is calling CheckAbort.
110-
if goexit || aborted {
111-
base.stack += "\n" + getStack(true, 1, 0)
108+
select {
109+
case <-ctx.Done():
110+
return nil, false
111+
case <-done:
112+
ok = true
113+
// Concatenate the inner call stack of the failure (which starts at the
114+
// goroutine instantiation) with the goroutine that is calling CheckAbort.
115+
if goexit || aborted {
116+
base.stack += "\n" + getStack(true, 1, 0)
117+
}
118+
119+
if goexit {
120+
abort = &Goexit{base}
121+
} else if aborted {
122+
abort = &Panic{base, recovered}
123+
}
124+
return
112125
}
126+
}
113127

114-
if goexit {
115-
abort = &Goexit{base}
116-
} else if aborted {
117-
abort = &Panic{base, recovered}
118-
}
119-
return
128+
// CheckAbort calls CheckAbortCtx with context.Background.
129+
func CheckAbort(function func()) Abort {
130+
abort, _ := CheckAbortCtx(context.Background(), function)
131+
return abort
120132
}
121133

122134
// getStack retrieves the current call stack as text, and optionally removes the

0 commit comments

Comments
 (0)