15
15
package test
16
16
17
17
import (
18
+ "context"
18
19
"runtime"
19
20
"strconv"
20
21
"strings"
@@ -130,17 +131,26 @@ type ConcurrentT struct {
130
131
t require.TestingT
131
132
failed bool
132
133
failedCh chan struct {}
134
+ ctx context.Context
133
135
134
136
mutex sync.Mutex
135
137
stages map [string ]* stage
136
138
}
137
139
138
140
// NewConcurrent creates a new concurrent testing object.
139
141
func NewConcurrent (t require.TestingT ) * ConcurrentT {
142
+ return NewConcurrentCtx (t , context .Background ())
143
+ }
144
+
145
+ // NewConcurrentCtx creates a new concurrent testing object controlled by a
146
+ // context. If that context expires, any ongoing stages and wait calls will
147
+ // fail.
148
+ func NewConcurrentCtx (t require.TestingT , ctx context.Context ) * ConcurrentT {
140
149
return & ConcurrentT {
141
150
t : t ,
142
151
stages : make (map [string ]* stage ),
143
152
failedCh : make (chan struct {}),
153
+ ctx : ctx ,
144
154
}
145
155
}
146
156
@@ -167,8 +177,10 @@ func (t *ConcurrentT) getStage(name string) *stage {
167
177
return s
168
178
}
169
179
170
- // Wait waits until the stages and barriers with the requested names terminate.
171
- // If any stage or barrier fails, terminates the current goroutine or test.
180
+ // Wait waits until the stages and barriers with the requested names
181
+ // terminate or the test's context expires. If the context expires, fails the
182
+ // test. If any stage or barrier fails, terminates the current goroutine or
183
+ // test.
172
184
func (t * ConcurrentT ) Wait (names ... string ) {
173
185
if len (names ) == 0 {
174
186
panic ("Wait(): called with 0 names" )
@@ -177,6 +189,11 @@ func (t *ConcurrentT) Wait(names ...string) {
177
189
for _ , name := range names {
178
190
stage := t .getStage (name )
179
191
select {
192
+ case <- t .ctx .Done ():
193
+ t .failNowMutex .Lock ()
194
+ t .t .Errorf ("Wait for stage %s: %v" , name , t .ctx .Err ())
195
+ t .failNowMutex .Unlock ()
196
+ t .FailNow ()
180
197
case <- stage .wg .WaitCh ():
181
198
if stage .failed .IsSet () {
182
199
t .FailNow ()
@@ -209,28 +226,41 @@ func (t *ConcurrentT) FailNow() {
209
226
// fn must not spawn any goroutines or pass along the T object to goroutines
210
227
// that call T.Fatal. To achieve this, make other goroutines call
211
228
// ConcurrentT.StageN() instead.
229
+ // If the test's context expires before the call returns, fails the test.
212
230
func (t * ConcurrentT ) StageN (name string , goroutines int , fn func (ConcT )) {
213
231
stage := t .spawnStage (name , goroutines )
214
232
215
233
stageT := ConcT {TestingT : stage , ct : t }
216
- abort := CheckAbort ( func () {
234
+ abort , ok := CheckAbortCtx ( t . ctx , func () {
217
235
fn (stageT )
218
236
})
219
237
220
- if abort != nil {
221
- // Fail the stage, if it had not been marked as such, yet.
222
- if stage .failed .TrySet () {
223
- defer stage .wg .Done ()
224
- }
225
- // If it is a panic or Goexit from certain contexts, print stack trace.
226
- if _ , ok := abort .(* Panic ); ok || shouldPrintStack (abort .Stack ()) {
227
- print ("\n " , abort .String ())
228
- }
238
+ if ok && abort == nil {
239
+ stage .pass ()
240
+ t .Wait (name )
241
+ return
242
+ }
243
+
244
+ // Fail the stage, if it had not been marked as such, yet.
245
+ if stage .failed .TrySet () {
246
+ defer stage .wg .Done ()
247
+ }
248
+
249
+ // If it did not terminate, just abort the test.
250
+ if ! ok {
251
+ t .failNowMutex .Lock ()
252
+ t .t .Errorf ("Stage %s: %v" , name , t .ctx .Err ())
253
+ t .failNowMutex .Unlock ()
229
254
t .FailNow ()
230
255
}
231
256
232
- stage .pass ()
233
- t .Wait (name )
257
+ // If it is a panic or Goexit from certain contexts, print stack trace.
258
+ if _ , ok := abort .(* Panic ); ok || shouldPrintStack (abort .Stack ()) {
259
+ t .failNowMutex .Lock ()
260
+ t .t .Errorf ("Stage %s: %s" , name , abort .String ())
261
+ t .failNowMutex .Unlock ()
262
+ }
263
+ t .FailNow ()
234
264
}
235
265
236
266
func shouldPrintStack (stack string ) bool {
0 commit comments