diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/search/HasJavaVersionTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/search/HasJavaVersionTest.java index 86b98ac5bbe..dabed976d9b 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/search/HasJavaVersionTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/search/HasJavaVersionTest.java @@ -21,7 +21,7 @@ import org.openrewrite.test.RewriteTest; import static org.openrewrite.java.Assertions.java; -import static org.openrewrite.java.Assertions.version; +import static org.openrewrite.java.Assertions.javaVersion; import static org.openrewrite.test.SourceSpecs.text; class HasJavaVersionTest implements RewriteTest { @@ -31,18 +31,16 @@ class HasJavaVersionTest implements RewriteTest { void matches(String version) { rewriteRun( spec -> spec.recipe(new HasJavaVersion(version, false)), - version( - java( - """ - class Test { - } - """, - """ - /*~~>*/class Test { - } - """ - ), - 11 + java( + """ + class Test { + } + """, + """ + /*~~>*/class Test { + } + """, + spec -> spec.markers(javaVersion(11)) ) ); } @@ -52,14 +50,12 @@ class Test { void noMatch(String version) { rewriteRun( spec -> spec.recipe(new HasJavaVersion(version, false)), - version( - java( - """ - class Test { - } - """ - ), - 17 + java( + """ + class Test { + } + """, + spec -> spec.markers(javaVersion(17)) ) ); } @@ -82,4 +78,22 @@ void declarativePrecondition() { ); } + @Test + void declarativePreconditionMatch() { + rewriteRun( + spec -> spec.recipeFromYaml(""" + --- + type: specs.openrewrite.org/v1beta/recipe + name: org.openrewrite.PreconditionTest + preconditions: + - org.openrewrite.java.search.HasJavaVersion: + version: 11 + recipeList: + - org.openrewrite.text.ChangeText: + toText: 2 + """, "org.openrewrite.PreconditionTest"), + text("1", "2", spec -> spec.markers(javaVersion(11))) + ); + } + } diff --git a/rewrite-java/src/main/java/org/openrewrite/java/search/HasJavaVersion.java b/rewrite-java/src/main/java/org/openrewrite/java/search/HasJavaVersion.java index 55deba66f8c..d85064d2316 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/search/HasJavaVersion.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/search/HasJavaVersion.java @@ -19,10 +19,7 @@ import lombok.Value; import org.openrewrite.*; import org.openrewrite.internal.lang.Nullable; -import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.marker.JavaVersion; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaSourceFile; import org.openrewrite.marker.SearchResult; import org.openrewrite.semver.Semver; import org.openrewrite.semver.VersionComparator; @@ -69,20 +66,19 @@ public Validated validate() { @Override public TreeVisitor getVisitor() { VersionComparator versionComparator = requireNonNull(Semver.validate(version, null).getValue()); - return new JavaIsoVisitor() { + return new TreeVisitor() { @Override - public J visit(@Nullable Tree tree, ExecutionContext ctx) { - if (tree instanceof JavaSourceFile) { - JavaSourceFile cu = (JavaSourceFile) requireNonNull(tree); - return cu.getMarkers().findFirst(JavaVersion.class) + public Tree visit(@Nullable Tree tree, ExecutionContext ctx) { + if (tree != null) { + return tree.getMarkers().findFirst(JavaVersion.class) .filter(version -> versionComparator.isValid(null, Integer.toString( Boolean.TRUE.equals(checkTargetCompatibility) ? version.getMajorReleaseVersion() : version.getMajorVersion()))) - .map(version -> SearchResult.found(cu)) - .orElse(cu); + .map(version -> SearchResult.found(tree)) + .orElse(tree); } - return (J) tree; + return tree; } }; }