Skip to content

Commit

Permalink
Merge pull request #3 from newmo-oss/fix-calling-original-tb
Browse files Browse the repository at this point in the history
Call original method
  • Loading branch information
tenntenn authored Dec 6, 2024
2 parents 8f3361d + 87ed15a commit dac995b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 27 deletions.
121 changes: 94 additions & 27 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type TB struct {
SkippedFunc func() bool
TempDirFunc func() string

testing.TB // for private method and unsupport method
testing.TB // for default behavior andd private method
}

// Record records the result of [Run].
Expand Down Expand Up @@ -72,131 +72,198 @@ func Run(f func(*TB)) *Record {
}

func (tb *TB) Cleanup(f func()) {
if tb.CleanupFunc != nil {
switch {
case tb.CleanupFunc != nil:
tb.CleanupFunc(f)
case tb.TB != nil:
tb.TB.Cleanup(f)
}
}

func (tb *TB) Error(args ...any) {
tb.record.Failed = true
if tb.ErrorFunc != nil {
switch {
case tb.ErrorFunc != nil:
tb.ErrorFunc(args...)
case tb.TB != nil:
tb.TB.Error(args...)
}
}

func (tb *TB) Errorf(format string, args ...any) {
tb.record.Failed = true
if tb.ErrorfFunc != nil {
switch {
case tb.ErrorfFunc != nil:
tb.ErrorfFunc(format, args...)
case tb.TB != nil:
tb.TB.Errorf(format, args...)
}
}

func (tb *TB) Fail() {
tb.record.Failed = true
if tb.FailFunc != nil {
switch {
case tb.FailFunc != nil:
tb.FailFunc()
case tb.TB != nil:
tb.TB.Fail()
}
}

func (tb *TB) FailNow() {
tb.record.Failed = true
tb.record.Goexit = true
if tb.FailNowFunc != nil {

switch {
case tb.FailNowFunc != nil:
tb.FailNowFunc()
} else {
runtime.Goexit()
case tb.TB != nil:
tb.TB.FailNow()
default:
runtime.Goexit()
}
}

func (tb *TB) Failed() bool {
if tb.FailedFunc != nil {
switch {
case tb.FailedFunc != nil:
return tb.FailedFunc()
case tb.TB != nil:
return tb.TB.Failed()
default:
return tb.record.Failed
}
return tb.record.Failed
}

func (tb *TB) Fatal(args ...any) {
tb.record.Failed = true
tb.record.Goexit = true
if tb.FatalFunc != nil {
switch {
case tb.FatalFunc != nil:
tb.FatalFunc(args...)
runtime.Goexit()
case tb.TB != nil:
tb.TB.Fatal(args...)
default:
runtime.Goexit()
}
runtime.Goexit()
}

func (tb *TB) Fatalf(format string, args ...any) {
tb.record.Failed = true
tb.record.Goexit = true
if tb.FatalfFunc != nil {

switch {
case tb.FatalfFunc != nil:
tb.FatalfFunc(format, args...)
runtime.Goexit()
case tb.TB != nil:
tb.TB.Fatalf(format, args...)
default:
runtime.Goexit()
}
runtime.Goexit()
}

func (tb *TB) Helper() {
if tb.HelperFunc != nil {
switch {
case tb.HelperFunc != nil:
tb.HelperFunc()
case tb.TB != nil:
tb.TB.Helper()
}
}

func (tb *TB) Log(args ...any) {
if tb.LogFunc != nil {
switch {
case tb.LogFunc != nil:
tb.LogFunc(args...)
case tb.TB != nil:
tb.TB.Log(args...)
}
}

func (tb *TB) Logf(format string, args ...any) {
if tb.LogfFunc != nil {
switch {
case tb.LogfFunc != nil:
tb.LogfFunc(format, args...)
case tb.TB != nil:
tb.TB.Logf(format, args...)
}
}

func (tb *TB) Name() string {
if tb.NameFunc != nil {
switch {
case tb.NameFunc != nil:
return tb.NameFunc()
case tb.TB != nil:
return tb.TB.Name()
default:
return ""
}
return ""
}

func (tb *TB) Setenv(key, value string) {
if tb.SetenvFunc != nil {
switch {
case tb.SetenvFunc != nil:
tb.SetenvFunc(key, value)
case tb.TB != nil:
tb.TB.Setenv(key, value)
}
}

func (tb *TB) Skip(args ...any) {
tb.record.Skipped = true
if tb.SkipFunc != nil {
switch {
case tb.SkipFunc != nil:
tb.SkipFunc(args...)
case tb.TB != nil:
tb.TB.Skip(args...)
}
}

func (tb *TB) SkipNow() {
tb.record.Skipped = true
tb.record.Goexit = true
if tb.SkipNowFunc != nil {
switch {
case tb.SkipNowFunc != nil:
tb.SkipNowFunc()
runtime.Goexit()
case tb.TB != nil:
tb.TB.SkipNow()
default:
runtime.Goexit()
}
runtime.Goexit()
}

func (tb *TB) Skipf(format string, args ...any) {
tb.record.Skipped = true
if tb.SkipfFunc != nil {
switch {
case tb.SkipfFunc != nil:
tb.SkipfFunc(format, args...)
case tb.TB != nil:
tb.TB.Skipf(format, args...)
}
}

func (tb *TB) Skipped() bool {
if tb.SkippedFunc != nil {
switch {
case tb.SkippedFunc != nil:
return tb.SkippedFunc()
case tb.TB != nil:
return tb.TB.Skipped()
default:
return tb.record.Skipped
}
return tb.record.Skipped
}

func (tb *TB) TempDir() string {
if tb.TempDirFunc != nil {
switch {
case tb.TempDirFunc != nil:
return tb.TempDirFunc()
case tb.TB != nil:
return tb.TB.TempDir()
default:
return ""
}
return ""
}
54 changes: 54 additions & 0 deletions testing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,60 @@ func TestTB_Fields(t *testing.T) {
}
}

func TestTB_DefaultMethod(t *testing.T) {
t.Parallel()

typ := reflect.TypeFor[gotestingmock.TB]()
for i := range typ.NumField() {
ft := typ.Field(i)
if !strings.HasSuffix(ft.Name, "Func") &&
ft.Type.Kind() != reflect.Func {
continue
}

t.Run(ft.Name, func(t *testing.T) {
t.Parallel()

method := strings.TrimSuffix(ft.Name, "Func")

var call bool
rec := gotestingmock.Run(func(parent *gotestingmock.TB) {
tb := &gotestingmock.TB{TB: parent}

pv := reflect.ValueOf(parent)
pfv := pv.Elem().Field(i)

/*
parent.XxxFunc = func() {
call = true
}
parent.Xxx()
*/

pfv.Set(reflect.MakeFunc(ft.Type, func([]reflect.Value) []reflect.Value {
call = true
ret := make([]reflect.Value, pfv.Type().NumOut())
for i := range pfv.Type().NumOut() {
ret[i] = reflect.New(pfv.Type().Out(i)).Elem()
}
return ret
}))

v := reflect.ValueOf(tb)
callWithZeros(v.MethodByName(method))
})

if rec.PanicValue != nil {
t.Fatal("unexpected panic:", rec.PanicValue)
}

if !call {
t.Errorf("(*gotestingmock.TB).%[1]s did not call with (testing.TB).%[1]s", method)
}
})
}
}

func TestRecord(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit dac995b

Please sign in to comment.