diff --git a/matchers/verification.go b/matchers/verification.go index 5e7541c..0835c49 100644 --- a/matchers/verification.go +++ b/matchers/verification.go @@ -15,10 +15,6 @@ type InvocationData struct { Args []reflect.Value } -type InstanceVerifier interface { - RecordInteraction(data *InvocationData) error -} - type MethodVerifier interface { Verify(data *MethodVerificationData) error } @@ -47,12 +43,6 @@ func MethodVerifierFromFunc(f func(data *MethodVerificationData) error) MethodVe } } -func InstanceVerifierFromFunc(f func(data *InvocationData) error) InstanceVerifier { - return &instanceVerifierImpl{ - f: f, - } -} - type methodVerifierImpl struct { f func(data *MethodVerificationData) error } @@ -60,11 +50,3 @@ type methodVerifierImpl struct { func (m *methodVerifierImpl) Verify(data *MethodVerificationData) error { return m.f(data) } - -type instanceVerifierImpl struct { - f func(data *InvocationData) error -} - -func (i *instanceVerifierImpl) RecordInteraction(data *InvocationData) error { - return i.f(data) -} diff --git a/mock/api.go b/mock/api.go index b1b66c1..a359993 100644 --- a/mock/api.go +++ b/mock/api.go @@ -1,6 +1,7 @@ package mock import ( + "context" "fmt" "github.com/ovechkin-dm/mockio/matchers" "github.com/ovechkin-dm/mockio/registry" @@ -97,6 +98,12 @@ func AnyInterface() any { return Any[any]() } +// AnyContext is an alias for Any[context.Context] +// See Any for more description +func AnyContext() context.Context { + return Any[context.Context]() +} + // AnyOfType is an alias for Any[T] for specific type // Used for automatic type inference func AnyOfType[T any](t T) T { @@ -354,7 +361,7 @@ func Never() matchers.MethodVerifier { } // VerifyNoMoreInteractions verifies that there are no more unverified interactions with the mock object. -// +// For example if // Example usage: // // // Create a mock object for testing @@ -369,7 +376,5 @@ func Never() matchers.MethodVerifier { // // Verify that there are no more unverified interactions // VerifyNoMoreInteractions(mockObj) func VerifyNoMoreInteractions(value any) { - registry.VerifyInstance(value, matchers.InstanceVerifierFromFunc(func(data *matchers.InvocationData) error { - return fmt.Errorf("no more interactions should be recorded for mock") - })) + registry.VerifyNoMoreInteractions(value) } diff --git a/registry/handler.go b/registry/handler.go index aac6591..257200c 100644 --- a/registry/handler.go +++ b/registry/handler.go @@ -13,7 +13,6 @@ import ( type invocationHandler struct { ctx *mockContext calls []*methodRecorder - instanceVerifiers []matchers.InstanceVerifier lock sync.Mutex instanceType reflect.Type } @@ -35,10 +34,6 @@ func (h *invocationHandler) Handle(method *dyno.Method, values []reflect.Value) func (h *invocationHandler) DoAnswer(c *MethodCall) []reflect.Value { rec := h.calls[c.Method.Num] - ok := h.VerifyInstance(c) - if !ok { - return createDefaultReturnValues(c.Method.Type) - } h.ctx.getState().whenHandler = h h.ctx.getState().whenCall = c var matched bool @@ -130,26 +125,6 @@ func (h *invocationHandler) When() matchers.ReturnerAll { return NewReturnerAll(h.ctx, m) } -func (h *invocationHandler) VerifyInstance(m *MethodCall) bool { - data := &matchers.InvocationData{ - MethodType: m.Method.Type, - MethodName: m.Method.Name, - Args: m.Values, - } - for _, v := range h.instanceVerifiers { - err := v.RecordInteraction(data) - if err != nil { - h.ctx.reporter.FailNow(err) - return false - } - } - return true -} - -func (h *invocationHandler) AddInstanceVerifier(v matchers.InstanceVerifier) { - h.instanceVerifiers = append(h.instanceVerifiers, v) -} - func (h *invocationHandler) VerifyMethod(verifier matchers.MethodVerifier) { h.lock.Lock() defer h.lock.Unlock() @@ -198,6 +173,7 @@ func (h *invocationHandler) DoVerifyMethod(call *MethodCall) []reflect.Value { } if matches { + c.Verified = true numMethodCalls += 1 } } @@ -225,7 +201,6 @@ func newHandler[T any](holder *mockContext) *invocationHandler { return &invocationHandler{ ctx: holder, calls: recorders, - instanceVerifiers: make([]matchers.InstanceVerifier, 0), instanceType: tp, } } @@ -308,3 +283,16 @@ func (h *invocationHandler) CheckUnusedStubs() { } } } + +func (h *invocationHandler) VerifyNoMoreInteractions() { + for _, rec := range h.calls { + for _, call := range rec.calls { + if call.WhenCall { + continue + } + if !call.Verified { + h.ctx.reporter.ReportNoMoreInteractionsExpected(h.instanceType, call) + } + } + } +} diff --git a/registry/registry.go b/registry/registry.go index 1930143..e25c768 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -107,14 +107,14 @@ func VerifyMethod(t any, v matchers.MethodVerifier) { }) } -func VerifyInstance(t any, v matchers.InstanceVerifier) { +func VerifyNoMoreInteractions(t any) { withCheck(func() any { handler, ok := getInstance().mapping[t] if !ok { getInstance().mockContext.reporter.ReportUnregisteredMockVerify(t) return nil } - handler.AddInstanceVerifier(v) + handler.VerifyNoMoreInteractions() return nil }) } diff --git a/registry/reporter.go b/registry/reporter.go index b8f5f8e..4c3f956 100644 --- a/registry/reporter.go +++ b/registry/reporter.go @@ -202,3 +202,8 @@ func PrettyPrintMethodInvocation(interfaceType reflect.Type, method reflect.Meth sb.WriteRune(')') return sb.String() } + +func (e *EnrichedReporter) ReportNoMoreInteractionsExpected(instanceType reflect.Type, call *MethodCall) { + methodSig := prettyPrintMethodSignature(instanceType, call.Method.Type) + e.Errorf("no more interactions expected on %v", methodSig) +} diff --git a/registry/state.go b/registry/state.go index 847f067..0eaeed7 100644 --- a/registry/state.go +++ b/registry/state.go @@ -116,4 +116,5 @@ type MethodCall struct { Method *dyno.Method Values []reflect.Value WhenCall bool + Verified bool } diff --git a/tests/verify/verify_test.go b/tests/verify/verify_test.go index 7407a9f..e3a18db 100644 --- a/tests/verify/verify_test.go +++ b/tests/verify/verify_test.go @@ -63,12 +63,50 @@ func TestVerifyNeverFails(t *testing.T) { r.AssertError() } -func TestNoMoreInteractions(t *testing.T) { +func TestNoMoreInteractionsFails(t *testing.T) { r := common.NewMockReporter(t) SetUp(r) m := Mock[iface]() WhenSingle(m.Foo(Any[int]())).ThenReturn(10) + m.Foo(10) VerifyNoMoreInteractions(m) + r.AssertError() +} + +func TestNoMoreInteractionsSuccess(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + m := Mock[iface]() + WhenSingle(m.Foo(Any[int]())).ThenReturn(10) m.Foo(10) + Verify(m, Once()).Foo(10) + VerifyNoMoreInteractions(m) + r.AssertNoError() +} + +func TestNoMoreInteractionsComplexFail(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + m := Mock[iface]() + WhenSingle(m.Foo(10)).ThenReturn(10) + WhenSingle(m.Foo(11)).ThenReturn(10) + m.Foo(10) + m.Foo(11) + Verify(m, Once()).Foo(10) + VerifyNoMoreInteractions(m) r.AssertError() } + +func TestNoMoreInteractionsComplexSuccess(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + m := Mock[iface]() + WhenSingle(m.Foo(10)).ThenReturn(10) + WhenSingle(m.Foo(11)).ThenReturn(10) + m.Foo(10) + m.Foo(11) + Verify(m, AtLeastOnce()).Foo(AnyInt()) + Verify(m, Once()).Foo(11) + VerifyNoMoreInteractions(m) + r.AssertNoError() +}