diff --git a/src/main/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocks.java b/src/main/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocks.java index fe15442c0..ff1c9cd26 100644 --- a/src/main/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocks.java +++ b/src/main/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocks.java @@ -25,13 +25,12 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.search.UsesMethod; -import org.openrewrite.java.tree.Expression; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.Statement; -import org.openrewrite.java.tree.TypeUtils; +import org.openrewrite.java.tree.*; import java.util.Collections; +import java.util.Objects; import java.util.Set; +import java.util.stream.Stream; public class RemoveTryCatchFailBlocks extends Recipe { private static final MethodMatcher ASSERT_FAIL_NO_ARG = new MethodMatcher("org.junit.jupiter.api.Assertions fail()"); @@ -120,6 +119,7 @@ private static boolean isException(Expression expression) { @NotNull private J.MethodInvocation replaceWithAssertDoesNotThrowWithoutStringExpression(ExecutionContext ctx, J.Try try_) { maybeAddImport("org.junit.jupiter.api.Assertions"); + maybeRemoveCatchTypes(try_); return JavaTemplate.builder("Assertions.assertDoesNotThrow(() -> #{any()})") .contextSensitive() .imports("org.junit.jupiter.api.Assertions") @@ -128,10 +128,22 @@ private J.MethodInvocation replaceWithAssertDoesNotThrowWithoutStringExpression( .apply(getCursor(), try_.getCoordinates().replace(), try_.getBody()); } + private void maybeRemoveCatchTypes(J.Try try_) { + JavaType catchType = try_.getCatches().get(0).getParameter().getTree().getType(); + if (catchType != null) { + Stream.of(catchType) + .flatMap(t -> t instanceof JavaType.MultiCatch ? ((JavaType.MultiCatch) t).getThrowableTypes().stream() : Stream.of(t)) + .map(TypeUtils::asFullyQualified) + .filter(Objects::nonNull) + .forEach(t -> maybeRemoveImport(t.getFullyQualifiedName())); + } + } + @NotNull private J.MethodInvocation replaceWithAssertDoesNotThrowWithStringExpression(ExecutionContext ctx, J.Try try_, Expression failCallArgument) { // Retain the fail(String) call argument maybeAddImport("org.junit.jupiter.api.Assertions"); + maybeRemoveCatchTypes(try_); return JavaTemplate.builder("Assertions.assertDoesNotThrow(() -> #{any()}, #{any(String)})") .contextSensitive() .imports("org.junit.jupiter.api.Assertions") diff --git a/src/test/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocksTest.java b/src/test/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocksTest.java index 063c58a66..6f6ec1dae 100644 --- a/src/test/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocksTest.java +++ b/src/test/java/org/openrewrite/java/testing/junit5/RemoveTryCatchFailBlocksTest.java @@ -535,6 +535,7 @@ public String anotherMethod() { ); } + @SuppressWarnings("resource") @Test void failHasBinaryWithGetMessage() { //language=java @@ -543,13 +544,15 @@ void failHasBinaryWithGetMessage() { """ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; + import java.io.FileOutputStream; + import java.io.IOException; class MyTest { @Test public void testMethod() { try { - int divide = 50 / 0; - } catch (Exception e) { + FileOutputStream outputStream = new FileOutputStream("test.txt"); + } catch (IOException | RuntimeException e) { Assertions.fail("The error is: " + e.getMessage()); } } @@ -558,12 +561,13 @@ public void testMethod() { """ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; + import java.io.FileOutputStream; class MyTest { @Test public void testMethod() { Assertions.assertDoesNotThrow(() -> { - int divide = 50 / 0; + FileOutputStream outputStream = new FileOutputStream("test.txt"); }, "The error is: "); } }