diff --git a/registry/handler.go b/registry/handler.go index 8d3dacf..f8cc029 100644 --- a/registry/handler.go +++ b/registry/handler.go @@ -19,7 +19,7 @@ type invocationHandler struct { func (h *invocationHandler) Handle(method *dyno.Method, values []reflect.Value) []reflect.Value { h.lock.Lock() defer h.lock.Unlock() - + values = h.refineValues(method, values) call := &MethodCall{ Method: method, Values: values, @@ -39,6 +39,9 @@ func (h *invocationHandler) DoAnswer(c *MethodCall) []reflect.Value { var matched bool for _, mm := range rec.methodMatches { matched = true + if len(mm.matchers) != len(c.Values) { + continue + } for argIdx, matcher := range mm.matchers { if !matcher.matcher.Match(valueSliceToInterfaceSlice(c.Values), c.Values[argIdx].Interface()) { matched = false @@ -167,6 +170,9 @@ func (h *invocationHandler) DoVerifyMethod(call *MethodCall) []reflect.Value { if c.Method.Type != call.Method.Type { continue } + if len(argMatchers) != len(c.Values) { + continue + } for i := range argMatchers { if !argMatchers[i].matcher.Match(valueSliceToInterfaceSlice(c.Values), c.Values[i].Interface()) { @@ -226,8 +232,7 @@ func (h *invocationHandler) validateMatchers(call *MethodCall) bool { } h.ctx.getState().matchers = argMatchers } - mt := call.Method.Type - if len(argMatchers) != mt.Type.NumIn() { + if len(argMatchers) != len(call.Values) { h.ctx.reporter.ReportInvalidUseOfMatchers(h.instanceType, call, argMatchers) return false } @@ -287,3 +292,19 @@ func (h *invocationHandler) VerifyNoMoreInteractions() { h.ctx.reporter.ReportNoMoreInteractionsExpected(h.instanceType, unexpected) } } + +func (h *invocationHandler) refineValues(method *dyno.Method, values []reflect.Value) []reflect.Value { + tp := method.Type.Type + if tp.IsVariadic() { + result := make([]reflect.Value, 0) + for i := 0; i < tp.NumIn()-1; i++ { + result = append(result, values[i]) + } + last := values[len(values)-1] + for i := 0; i < last.Len(); i++ { + result = append(result, last.Index(i)) + } + return result + } + return values +} diff --git a/registry/registry.go b/registry/registry.go index e9ae4ea..3d9a9b8 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -125,7 +125,7 @@ func newRegistry() any { reporter := &EnrichedReporter{&panicReporter{}} return &Registry{ mockContext: newMockContext(reporter), - mapping: make(map[any]*invocationHandler, 0), + mapping: make(map[any]*invocationHandler), } } diff --git a/registry/reporter.go b/registry/reporter.go index 57aa31a..6b38b2e 100644 --- a/registry/reporter.go +++ b/registry/reporter.go @@ -87,9 +87,12 @@ func (e *EnrichedReporter) ReportInvalidUseOfMatchers(instanceType reflect.Type, declarationLines = append(declarationLines, "\t\t" + m[i].stackTrace.CallerLine()) } decl := strings.Join(declarationLines, "\n") + expectedStr := fmt.Sprintf("%v expected, %v recorded:\n", numExpected, numActual) + if call.Method.Type.Type.IsVariadic() { + expectedStr = "" + } e.StackTraceErrorf(`Invalid use of matchers - %v expected, %v recorded: -%v + %s%v method: %v expected: @@ -99,7 +102,7 @@ func (e *EnrichedReporter) ReportInvalidUseOfMatchers(instanceType reflect.Type, This can happen for 2 reasons: 1. Declaration of matcher outside When() call 2. Mixing matchers and exact values in When() call. Is this case, consider using "Exact" matcher.`, - numExpected, numActual, decl, methodSig, inArgsStr, matchersString) + expectedStr, decl, methodSig, inArgsStr, matchersString) } func (e *EnrichedReporter) ReportCaptorInsideVerify(call *MethodCall, m []*matcherWrapper) { @@ -132,6 +135,9 @@ func (e *EnrichedReporter) ReportVerifyMethodError( other := strings.Builder{} for j, c := range recorder.calls { + if c.WhenCall { + continue + } callArgs := make([]string, len(c.Values)) for i := range c.Values { callArgs[i] = fmt.Sprintf("%v", c.Values[i]) diff --git a/tests/reporting/error_reporting_test.go b/tests/reporting/error_reporting_test.go index 58a0f4a..e8ff08e 100644 --- a/tests/reporting/error_reporting_test.go +++ b/tests/reporting/error_reporting_test.go @@ -10,6 +10,7 @@ import ( type Foo interface { Bar() Baz(a int, b int, c int) int + VarArgs(a string, b ...int) int } func TestReportIncorrectWhenUsage(t *testing.T) { @@ -54,6 +55,16 @@ func TestInvalidUseOfMatchers(t *testing.T) { r.PrintError() } +func TestInvalidUseOfMatchersVarArgs(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + mock := Mock[Foo]() + When(mock.VarArgs(AnyString(), AnyInt(), 10)).ThenReturn(10) + mock.VarArgs("a", 2) + r.AssertError() + r.PrintError() +} + func TestCaptorInsideVerify(t *testing.T) { r := common.NewMockReporter(t) SetUp(r) @@ -76,6 +87,28 @@ func TestVerify(t *testing.T) { r.PrintError() } +func TestVerifyVarArgs(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + mock := Mock[Foo]() + When(mock.VarArgs(AnyString(), AnyInt(), AnyInt())).ThenReturn(10) + _ = mock.VarArgs("a", 10, 11) + Verify(mock, Once()).VarArgs(AnyString(), AnyInt(), Exact(10)) + r.AssertError() + r.PrintError() +} + +func TestVerifyDifferentVarArgs(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + mock := Mock[Foo]() + When(mock.VarArgs(AnyString(), AnyInt(), AnyInt())).ThenReturn(10) + _ = mock.VarArgs("a", 10, 11) + Verify(mock, Once()).VarArgs(AnyString(), AnyInt(), AnyInt(), AnyInt()) + r.AssertError() + r.PrintError() +} + func TestVerifyTimes(t *testing.T) { r := common.NewMockReporter(t) SetUp(r) @@ -118,6 +151,18 @@ func TestNoMoreInteractions(t *testing.T) { r.PrintError() } +func TestNoMoreInteractionsVarArgs(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + mock := Mock[Foo]() + When(mock.VarArgs(AnyString(), AnyInt(), AnyInt())).ThenReturn("test", 10) + _ = mock.Baz(10, 10, 10) + _ = mock.Baz(10, 20, 10) + VerifyNoMoreInteractions(mock) + r.AssertError() + r.PrintError() +} + func TestUnexpectedMatchers(t *testing.T) { r := common.NewMockReporter(t) SetUp(r) diff --git a/tests/variadic/variadic_test.go b/tests/variadic/variadic_test.go new file mode 100644 index 0000000..2af66f3 --- /dev/null +++ b/tests/variadic/variadic_test.go @@ -0,0 +1,37 @@ +package variadic + +import ( + . "github.com/ovechkin-dm/mockio/mock" + "github.com/ovechkin-dm/mockio/tests/common" + "testing" +) + +type myInterface interface { + Foo(a ...int) int +} + +func TestVariadicSimple(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + m := Mock[myInterface]() + WhenSingle(m.Foo(1, 1)).ThenReturn(1) + WhenSingle(m.Foo(1)).ThenReturn(2) + ret := m.Foo(1) + r.AssertEqual(2, ret) + Verify(m, AtLeastOnce()).Foo(1) + r.AssertNoError() +} + +func TestCaptor(t *testing.T) { + r := common.NewMockReporter(t) + SetUp(r) + m := Mock[myInterface]() + c1 := Captor[int]() + c2 := Captor[int]() + WhenSingle(m.Foo(c1.Capture(), c2.Capture())).ThenReturn(1) + ret := m.Foo(1, 2) + r.AssertEqual(1, ret) + r.AssertEqual(c1.Last(), 1) + r.AssertEqual(c2.Last(), 2) + r.AssertNoError() +} \ No newline at end of file