Skip to content

Commit

Permalink
🔀 Merge pull request #112 from perun-network/110-concurrentt-waitctx
Browse files Browse the repository at this point in the history
✨ [pkg/test] Add context to ConcurrentT.
  • Loading branch information
ggwpez authored Jun 17, 2021
2 parents f5f7b3e + 9cb3942 commit c7113e1
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 39 deletions.
58 changes: 44 additions & 14 deletions pkg/test/concurrent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package test

import (
"context"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -130,17 +131,26 @@ type ConcurrentT struct {
t require.TestingT
failed bool
failedCh chan struct{}
ctx context.Context

mutex sync.Mutex
stages map[string]*stage
}

// NewConcurrent creates a new concurrent testing object.
func NewConcurrent(t require.TestingT) *ConcurrentT {
return NewConcurrentCtx(t, context.Background())
}

// NewConcurrentCtx creates a new concurrent testing object controlled by a
// context. If that context expires, any ongoing stages and wait calls will
// fail.
func NewConcurrentCtx(t require.TestingT, ctx context.Context) *ConcurrentT {
return &ConcurrentT{
t: t,
stages: make(map[string]*stage),
failedCh: make(chan struct{}),
ctx: ctx,
}
}

Expand All @@ -167,8 +177,10 @@ func (t *ConcurrentT) getStage(name string) *stage {
return s
}

// Wait waits until the stages and barriers with the requested names terminate.
// If any stage or barrier fails, terminates the current goroutine or test.
// Wait waits until the stages and barriers with the requested names
// terminate or the test's context expires. If the context expires, fails the
// test. If any stage or barrier fails, terminates the current goroutine or
// test.
func (t *ConcurrentT) Wait(names ...string) {
if len(names) == 0 {
panic("Wait(): called with 0 names")
Expand All @@ -177,6 +189,11 @@ func (t *ConcurrentT) Wait(names ...string) {
for _, name := range names {
stage := t.getStage(name)
select {
case <-t.ctx.Done():
t.failNowMutex.Lock()
t.t.Errorf("Wait for stage %s: %v", name, t.ctx.Err())
t.failNowMutex.Unlock()
t.FailNow()
case <-stage.wg.WaitCh():
if stage.failed.IsSet() {
t.FailNow()
Expand Down Expand Up @@ -209,28 +226,41 @@ func (t *ConcurrentT) FailNow() {
// fn must not spawn any goroutines or pass along the T object to goroutines
// that call T.Fatal. To achieve this, make other goroutines call
// ConcurrentT.StageN() instead.
// If the test's context expires before the call returns, fails the test.
func (t *ConcurrentT) StageN(name string, goroutines int, fn func(ConcT)) {
stage := t.spawnStage(name, goroutines)

stageT := ConcT{TestingT: stage, ct: t}
abort := CheckAbort(func() {
abort, ok := CheckAbortCtx(t.ctx, func() {
fn(stageT)
})

if abort != nil {
// Fail the stage, if it had not been marked as such, yet.
if stage.failed.TrySet() {
defer stage.wg.Done()
}
// If it is a panic or Goexit from certain contexts, print stack trace.
if _, ok := abort.(*Panic); ok || shouldPrintStack(abort.Stack()) {
print("\n", abort.String())
}
if ok && abort == nil {
stage.pass()
t.Wait(name)
return
}

// Fail the stage, if it had not been marked as such, yet.
if stage.failed.TrySet() {
defer stage.wg.Done()
}

// If it did not terminate, just abort the test.
if !ok {
t.failNowMutex.Lock()
t.t.Errorf("Stage %s: %v", name, t.ctx.Err())
t.failNowMutex.Unlock()
t.FailNow()
}

stage.pass()
t.Wait(name)
// If it is a panic or Goexit from certain contexts, print stack trace.
if _, ok := abort.(*Panic); ok || shouldPrintStack(abort.Stack()) {
t.failNowMutex.Lock()
t.t.Errorf("Stage %s: %s", name, abort.String())
t.failNowMutex.Unlock()
}
t.FailNow()
}

func shouldPrintStack(stack string) bool {
Expand Down
33 changes: 24 additions & 9 deletions pkg/test/concurrent_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package test_test

import (
"context"
"fmt"
"strconv"
"sync"
Expand Down Expand Up @@ -50,20 +51,34 @@ func TestConcurrentT_Wait(t *testing.T) {
})
ctxtest.AssertTerminates(t, timeout, func() { ct.Wait("known") })
})

t.Run("context expiry", func(t *testing.T) {
ctxtest.AssertTerminates(t, timeout, func() {
test.AssertFatal(t, func(t test.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
test.NewConcurrentCtx(t, ctx).Stage("", func(test.ConcT) {
time.Sleep(timeout)
})
})
})
})
}

func TestConcurrentT_FailNow(t *testing.T) {
var ct *test.ConcurrentT
t.Run("idempotence", func(t *testing.T) {
var ct *test.ConcurrentT

// Test that NewConcurrent.FailNow() calls T.FailNow().
test.AssertFatal(t, func(t test.T) {
ct = test.NewConcurrent(t)
ct.FailNow()
})
// Test that NewConcurrent.FailNow() calls T.FailNow().
test.AssertFatal(t, func(t test.T) {
ct = test.NewConcurrent(t)
ct.FailNow()
})

// Test that after that, FailNow() calls runtime.Goexit().
assert.True(t, test.CheckGoexit(ct.FailNow),
"redundant FailNow() must call runtime.Goexit()")
// Test that after that, FailNow() calls runtime.Goexit().
assert.True(t, test.CheckGoexit(ct.FailNow),
"redundant FailNow() must call runtime.Goexit()")
})

t.Run("hammer", func(t *testing.T) {
const parallel = 12
Expand Down
44 changes: 28 additions & 16 deletions pkg/test/goexit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package test

import (
"context"
"fmt"
"runtime/debug"
"strings"
Expand Down Expand Up @@ -64,10 +65,11 @@ func (g Goexit) String() string {
return "runtime.Goexit:\n\n" + g.Stack()
}

// CheckAbort tests whether a supplied function is aborted early using panic()
// or runtime.Goexit(). Returns a descriptor of the termination cause or nil if
// it terminated normally.
func CheckAbort(function func()) (abort Abort) {
// CheckAbortCtx tests whether a supplied function terminates within a context,
// and whether it is aborted early using panic() or runtime.Goexit(). Returns
// whether the function terminated before the expiry of the context and if so, a
// descriptor of the termination cause or nil if it terminated normally.
func CheckAbortCtx(ctx context.Context, function func()) (abort Abort, ok bool) {
done := make(chan struct{})

goexit := true // Whether runtime.Goexit occurred.
Expand Down Expand Up @@ -103,20 +105,30 @@ func CheckAbort(function func()) (abort Abort) {
goexit = false
}()

<-done

// Concatenate the inner call stack of the failure (which starts at the
// goroutine instantiation) with the goroutine that is calling CheckAbort.
if goexit || aborted {
base.stack += "\n" + getStack(true, 1, 0)
select {
case <-ctx.Done():
return nil, false
case <-done:
ok = true
// Concatenate the inner call stack of the failure (which starts at the
// goroutine instantiation) with the goroutine that is calling CheckAbort.
if goexit || aborted {
base.stack += "\n" + getStack(true, 1, 0)
}

if goexit {
abort = &Goexit{base}
} else if aborted {
abort = &Panic{base, recovered}
}
return
}
}

if goexit {
abort = &Goexit{base}
} else if aborted {
abort = &Panic{base, recovered}
}
return
// CheckAbort calls CheckAbortCtx with context.Background.
func CheckAbort(function func()) Abort {
abort, _ := CheckAbortCtx(context.Background(), function)
return abort
}

// getStack retrieves the current call stack as text, and optionally removes the
Expand Down

0 comments on commit c7113e1

Please sign in to comment.