Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JMockit to Mockito Recipe - Handle Typed Class Argument Matching and Collections #420

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
package org.openrewrite.java.testing.jmockit;

import java.util.*;
import java.util.regex.Pattern;

import lombok.EqualsAndHashCode;
import lombok.Value;
Expand Down Expand Up @@ -69,6 +68,14 @@ private static class RewriteExpectationsVisitor extends JavaIsoVisitor<Execution
JMOCKIT_ARGUMENT_MATCHERS.add("anyShort");
JMOCKIT_ARGUMENT_MATCHERS.add("any");
}
private static final Map<String, String> MOCKITO_COLLECTION_MATCHERS = new HashMap<>();
static {
MOCKITO_COLLECTION_MATCHERS.put("java.util.List", "anyList");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Set", "anySet");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Collection", "anyCollection");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Iterable", "anyIterable");
MOCKITO_COLLECTION_MATCHERS.put("java.util.Map", "anyMap");
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDeclaration, ExecutionContext ctx) {
Expand Down Expand Up @@ -116,7 +123,7 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl
if (expectationStatement instanceof J.MethodInvocation) {
if (!templateParams.isEmpty()) {
// apply template to build new method body
newBody = applyTemplate(ctx, templateParams, cursorLocation, coordinates);
newBody = rewriteMethodBody(ctx, templateParams, cursorLocation, coordinates);

// next statement coordinates are immediately after the statement just added
int newStatementIndex = bodyStatementIndex + mockitoStatementIndex;
Expand All @@ -138,15 +145,15 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl

// handle the last statement
if (!templateParams.isEmpty()) {
newBody = applyTemplate(ctx, templateParams, cursorLocation, coordinates);
newBody = rewriteMethodBody(ctx, templateParams, cursorLocation, coordinates);
}
}

return md.withBody(newBody);
}

private J.Block applyTemplate(ExecutionContext ctx, List<Object> templateParams, Object cursorLocation,
JavaCoordinates coordinates) {
private J.Block rewriteMethodBody(ExecutionContext ctx, List<Object> templateParams, Object cursorLocation,
JavaCoordinates coordinates) {
Expression result = null;
String methodName = "doNothing";
if (templateParams.size() > 1) {
Expand All @@ -166,29 +173,80 @@ private J.Block applyTemplate(ExecutionContext ctx, List<Object> templateParams,
);
}

private void rewriteArgumentMatchers(ExecutionContext ctx, List<Object> templateParams) {
J.MethodInvocation invocation = (J.MethodInvocation) templateParams.get(0);
private void rewriteArgumentMatchers(ExecutionContext ctx, List<Object> bodyTemplateParams) {
J.MethodInvocation invocation = (J.MethodInvocation) bodyTemplateParams.get(0);
List<Expression> newArguments = new ArrayList<>(invocation.getArguments().size());
for (Expression methodArgument : invocation.getArguments()) {
if (!isArgumentMatcher(methodArgument)) {
newArguments.add(methodArgument);
continue;
}
String argumentMatcher = ((J.Identifier) methodArgument).getSimpleName();
maybeAddImport("org.mockito.Mockito", argumentMatcher);
newArguments.add(JavaTemplate.builder(argumentMatcher + "()")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-3.12"))
.staticImports("org.mockito.Mockito." + argumentMatcher)
.build()
.apply(
new Cursor(getCursor(), methodArgument),
methodArgument.getCoordinates().replace()
));
String argumentMatcher, template;
List<Object> argumentTemplateParams = new ArrayList<>();
if (!(methodArgument instanceof J.TypeCast)) {
argumentMatcher = ((J.Identifier) methodArgument).getSimpleName();
template = argumentMatcher + "()";
newArguments.add(rewriteMethodArgument(ctx, argumentMatcher, template, methodArgument,
methodArgument.getCoordinates().replace(), argumentTemplateParams));
continue;
}
J.TypeCast tc = (J.TypeCast) methodArgument;
argumentMatcher = ((J.Identifier) tc.getExpression()).getSimpleName();
String className, fqn;
JavaType typeCastType = tc.getType();
if (typeCastType instanceof JavaType.Parameterized) {
// strip the raw type from the parameterized type
className = ((JavaType.Parameterized) typeCastType).getType().getClassName();
fqn = ((JavaType.Parameterized) typeCastType).getType().getFullyQualifiedName();
} else if (typeCastType instanceof JavaType.FullyQualified) {
className = ((JavaType.FullyQualified) typeCastType).getClassName();
fqn = ((JavaType.FullyQualified) typeCastType).getFullyQualifiedName();
} else {
throw new IllegalStateException("Unexpected value: " + typeCastType);
timtebeek marked this conversation as resolved.
Show resolved Hide resolved
}
if (MOCKITO_COLLECTION_MATCHERS.containsKey(fqn)) {
// mockito has specific argument matchers for collections
argumentMatcher = MOCKITO_COLLECTION_MATCHERS.get(fqn);
template = argumentMatcher + "()";
} else {
// rewrite parameter from ((<type>) any) to <type>.class
argumentTemplateParams.add(JavaTemplate.builder("#{}.class")
.javaParser(JavaParser.fromJavaVersion())
.imports(fqn)
.build()
.apply(
new Cursor(getCursor(), tc),
tc.getCoordinates().replace(),
className
));
template = argumentMatcher + "(#{any(java.lang.Class)})";
}

newArguments.add(rewriteMethodArgument(ctx, argumentMatcher, template, methodArgument,
methodArgument.getCoordinates().replace(), argumentTemplateParams));
}
templateParams.set(0, invocation.withArguments(newArguments));
bodyTemplateParams.set(0, invocation.withArguments(newArguments));
}

private Expression rewriteMethodArgument(ExecutionContext ctx, String argumentMatcher, String template,
Object cursorLocation, JavaCoordinates coordinates,
List<Object> templateParams) {
maybeAddImport("org.mockito.Mockito", argumentMatcher);
return JavaTemplate.builder(template)
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "mockito-core-3.12"))
.staticImports("org.mockito.Mockito." + argumentMatcher)
.build()
.apply(
new Cursor(getCursor(), cursorLocation),
coordinates,
templateParams.toArray()
);
}

private static boolean isArgumentMatcher(Expression expression) {
if (expression instanceof J.TypeCast) {
expression = ((J.TypeCast) expression).getExpression();
}
if (!(expression instanceof J.Identifier)) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,74 @@ void test() throws RuntimeException {
);
}

@Test
void jMockitExpectationsToMockitoWhenClassArgumentMatcher() {
//language=java
rewriteRun(
java(
"""
import java.util.List;

class MyObject {
public String getSomeField(List<String> input) {
return "X";
}
}
"""
),
java(
"""
import java.util.ArrayList;
import java.util.List;

import mockit.Expectations;
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;

import static org.junit.jupiter.api.Assertions.assertNotNull;

@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;

void test() {
new Expectations() {{
myObject.getSomeField((List<String>) any);
result = null;
}};
assertNotNull(myObject.getSomeField(new ArrayList<>()));
}
}
""",
"""
import java.util.ArrayList;
import java.util.List;

import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class MyTest {
@Mock
MyObject myObject;

void test() {
when(myObject.getSomeField(anyList())).thenReturn(null);
assertNotNull(myObject.getSomeField(new ArrayList<>()));
}
}
"""
)
);
}

@Test
void jMockitExpectationsToMockitoWhenMultipleStatements() {
//language=java
Expand All @@ -480,18 +548,18 @@ public void doSomething() {}
import mockit.Mocked;
import mockit.integration.junit5.JMockitExtension;
import org.junit.jupiter.api.extension.ExtendWith;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

@ExtendWith(JMockitExtension.class)
class MyTest {
@Mocked
MyObject myObject;

@Mocked
MyObject myOtherObject;

void test() {
new Expectations() {{
myObject.getSomeIntField();
Expand Down