From 72793ae979620bd699ddc6e34904d6162550db8d Mon Sep 17 00:00:00 2001 From: SiBorea <108953913+SiBorea@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:46:11 +0800 Subject: [PATCH] Fix AddImport match variable's name (#4747) * Fix AddImport match variable's name * Apply formatter * Run a single test * Simplify if/else --------- Co-authored-by: Tim te Beek --- .../org/openrewrite/java/AddImportTest.java | 291 ++++++++++-------- .../java/org/openrewrite/java/AddImport.java | 26 +- 2 files changed, 183 insertions(+), 134 deletions(-) diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java index e5b6c404e5f..d82981f8863 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/AddImportTest.java @@ -37,9 +37,7 @@ import static java.util.Collections.emptySet; import static java.util.Collections.singletonList; -import static org.openrewrite.java.Assertions.addTypesToSourceSet; -import static org.openrewrite.java.Assertions.java; -import static org.openrewrite.java.Assertions.srcMainJava; +import static org.openrewrite.java.Assertions.*; import static org.openrewrite.test.RewriteTest.toRecipe; @SuppressWarnings("rawtypes") @@ -113,7 +111,7 @@ void dontDuplicateImports() { """ import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus.Series; - + class A {} """ ) @@ -139,7 +137,7 @@ class A {} import org.junit.jupiter.api.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - + class A {} """ ) @@ -155,16 +153,16 @@ void dontDuplicateImports3() { """ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; - + import java.util.List; class A {} """, """ import static org.junit.jupiter.api.Assertions.*; - + import java.util.List; - + class A {} """ ) @@ -178,7 +176,7 @@ void dontImportYourself() { java( """ package com.myorg; - + class A { } """ @@ -241,7 +239,7 @@ void dontImportFromSamePackage() { java( """ package com.myorg; - + class B { } """ @@ -249,7 +247,7 @@ class B { java( """ package com.myorg; - + class A { } """ @@ -309,7 +307,7 @@ void addNamedImport() { java("class A {}", """ import java.util.List; - + class A {} """ ) @@ -323,7 +321,7 @@ void doNotAddImportIfNotReferenced() { java( """ package a; - + class A {} """ ) @@ -337,22 +335,22 @@ void addImportInsertsNewMiddleBlock() { java( """ package a; - + import com.sun.naming.*; - + import static java.util.Collections.*; - + class A {} """, """ package a; - + import com.sun.naming.*; - + import java.util.List; - + import static java.util.Collections.*; - + class A {} """ ) @@ -366,14 +364,14 @@ void addFirstImport() { java( """ package a; - + class A {} """, """ package a; - + import java.util.List; - + class A {} """ ) @@ -385,22 +383,22 @@ class A {} void addImportIfReferenced() { rewriteRun( spec -> spec.recipe(toRecipe(() -> - new JavaIsoVisitor<>() { - @Override - public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) { - J.ClassDeclaration c = super.visitClassDeclaration(classDecl, ctx); - maybeAddImport("java.math.BigDecimal"); - maybeAddImport("java.math.RoundingMode"); - return JavaTemplate.builder("BigDecimal d = BigDecimal.valueOf(1).setScale(1, RoundingMode.HALF_EVEN);") - .imports("java.math.BigDecimal", "java.math.RoundingMode") - .build() - .apply( - updateCursor(c), - c.getBody().getCoordinates().lastStatement() - ); - } - } - ).withMaxCycles(1)), + new JavaIsoVisitor<>() { + @Override + public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) { + J.ClassDeclaration c = super.visitClassDeclaration(classDecl, ctx); + maybeAddImport("java.math.BigDecimal"); + maybeAddImport("java.math.RoundingMode"); + return JavaTemplate.builder("BigDecimal d = BigDecimal.valueOf(1).setScale(1, RoundingMode.HALF_EVEN);") + .imports("java.math.BigDecimal", "java.math.RoundingMode") + .build() + .apply( + updateCursor(c), + c.getBody().getCoordinates().lastStatement() + ); + } + } + ).withMaxCycles(1)), java( """ package a; @@ -410,10 +408,10 @@ class A { """, """ package a; - + import java.math.BigDecimal; import java.math.RoundingMode; - + class A { BigDecimal d = BigDecimal.valueOf(1).setScale(1, RoundingMode.HALF_EVEN); } @@ -429,7 +427,7 @@ void doNotAddWildcardImportIfNotReferenced() { java( """ package a; - + class A {} """ ) @@ -443,7 +441,7 @@ void lastImportWhenFirstClassDeclarationHasJavadoc() { java( """ import java.util.List; - + /** * My type */ @@ -451,9 +449,9 @@ class A {} """, """ import java.util.List; - + import static java.util.Collections.*; - + /** * My type */ @@ -474,9 +472,9 @@ class A {} """, """ package a; - + import java.util.List; - + class A {} """ ) @@ -516,18 +514,18 @@ public class B {} java( """ package a; - + import c.C0; import c.c.C1; import c.c.c.C2; - + class A {} """, String.format(""" package a; - + %s - + class A {} """, expectedImports.stream().map(i -> "import " + i + ";").collect(Collectors.joining("\n")) @@ -549,7 +547,7 @@ void doNotAddImportIfAlreadyExists() { java( """ package a; - + import java.util.List; class A {} """ @@ -564,7 +562,7 @@ void doNotAddImportIfCoveredByStarImport() { java( """ package a; - + import java.util.*; class A {} """ @@ -595,17 +593,17 @@ void addNamedImportIfStarStaticImportExists() { java( """ package a; - + import static java.util.List.*; class A {} """, """ package a; - + import java.util.List; - + import static java.util.List.*; - + class A {} """ ) @@ -623,9 +621,9 @@ class A {} """, """ import java.util.*; - + import static java.util.Collections.emptyList; - + class A {} """ ) @@ -640,7 +638,7 @@ void addStaticImportForUnreferencedField() { java( """ package mycompany; - + public class Type { public static String FIELD; } @@ -650,7 +648,7 @@ public class Type { "class A {}", """ import static mycompany.Type.FIELD; - + class A {} """ ) @@ -683,17 +681,17 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, Ex java( """ public class A { - + } """, """ import java.time.temporal.ChronoUnit; - + import static java.time.temporal.ChronoUnit.MILLIS; - + public class A { ChronoUnit unit = MILLIS; - + } """ ) @@ -708,9 +706,9 @@ void dontAddImportToStaticFieldWithNamespaceConflict() { java( """ package a; - + import java.time.temporal.ChronoUnit; - + class A { static final int MILLIS = 1; ChronoUnit unit = ChronoUnit.MILLIS; @@ -727,7 +725,7 @@ void dontAddStaticWildcardImportIfNotReferenced() { java( """ package a; - + class A {} """ ) @@ -749,9 +747,9 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu java( """ package a; - + import java.util.List; - + class A { public A() { List list = java.util.Collections.emptyList(); @@ -760,11 +758,11 @@ public A() { """, """ package a; - + import java.util.List; - + import static java.util.Collections.emptyList; - + class A { public A() { List list = emptyList(); @@ -775,6 +773,49 @@ public A() { ); } + @Test + void addNamedStaticImportWhenReferenced2() { + rewriteRun( + spec -> spec.recipe(toRecipe(() -> new JavaIsoVisitor<>() { + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext executionContext) { + method = super.visitMethodDeclaration(method, executionContext); + method = JavaTemplate.builder("List list = new ArrayList<>();") + .imports("java.util.ArrayList", "java.util.List") + .staticImports("java.util.Calendar.Builder") + .build() + .apply(getCursor(), method.getBody().getCoordinates().firstStatement()); + maybeAddImport("java.util.ArrayList"); + maybeAddImport("java.util.List"); + maybeAddImport("java.util.Calendar", "Builder"); + return method; + } + }).withMaxCycles(1)), + java( + """ + import static java.util.Calendar.Builder; + + class A { + public A() { + } + } + """, + """ + import java.util.ArrayList; + import java.util.List; + + import static java.util.Calendar.Builder; + + class A { + public A() { + List list = new ArrayList<>(); + } + } + """ + ) + ); + } + @Test void doNotAddNamedStaticImportIfNotReferenced() { rewriteRun( @@ -782,7 +823,7 @@ void doNotAddNamedStaticImportIfNotReferenced() { java( """ package a; - + class A {} """ ) @@ -807,9 +848,9 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu java( """ package a; - + import java.util.List; - + class A { public A() { List list = java.util.Collections.emptyList(); @@ -818,11 +859,11 @@ public A() { """, """ package a; - + import java.util.List; - + import static java.util.Collections.*; - + class A { public A() { List list = emptyList(); @@ -883,14 +924,14 @@ public class C { """ import foo.B; import foo.C; - + import java.util.Collections; import java.util.List; import java.util.HashSet; import java.util.HashMap; import java.util.Map; import java.util.Set; - + class A { B b = new B(); C c = new C(); @@ -903,7 +944,7 @@ class A { """ import foo.B; import foo.C; - + import java.util.*; class A { @@ -928,15 +969,15 @@ void addImportWhenDuplicatesExist() { """ import javax.ws.rs.Path; import javax.ws.rs.Path; - + class A {} """, """ import org.springframework.http.MediaType; - + import javax.ws.rs.Path; import javax.ws.rs.Path; - + class A {} """ ) @@ -952,15 +993,15 @@ void unorderedImportsWithNewBlock() { """ import org.foo.B; import org.foo.A; - + class A {} """, """ import org.foo.B; import org.foo.A; - + import java.time.Duration; - + class A {} """ ) @@ -981,7 +1022,7 @@ void doNotFoldNormalImportWithNamespaceConflict() { import java.util.Collections; import java.util.Map; import java.util.Set; - + @SuppressWarnings("ALL") class Test { List list; @@ -994,7 +1035,7 @@ class Test { import java.util.List; import java.util.Map; import java.util.Set; - + @SuppressWarnings("ALL") class Test { List list; @@ -1136,7 +1177,7 @@ void noImportLayout() { """, """ import java.util.List; - + import static java.util.Collections.*; """ ) @@ -1154,9 +1195,9 @@ class A {} """.replace("\n", "\r\n"), """ package a; - + import java.util.List; - + class A {} """.replace("\n", "\r\n") ) @@ -1170,17 +1211,17 @@ void crlfNewLinesWithPreviousImports() { java( """ package a; - + import java.util.Set; - + class A {} """.replace("\n", "\r\n"), """ package a; - + import java.util.List; import java.util.Set; - + class A {} """.replace("\n", "\r\n") ) @@ -1194,13 +1235,13 @@ void crlfNewLinesWithPreviousImportsNoPackage() { java( """ import java.util.Set; - + class A {} """.replace("\n", "\r\n"), """ import java.util.List; import java.util.Set; - + class A {} """.replace("\n", "\r\n") ) @@ -1214,13 +1255,13 @@ void crlfNewLinesWithPreviousImportsNoClass() { java( """ package a; - + import java.util.Arrays; import java.util.Set; """.replace("\n", "\r\n"), """ package a; - + import java.util.Arrays; import java.util.List; import java.util.Set; @@ -1269,10 +1310,10 @@ void crlfNewLinesInComments() { * limitations under the License. */ """.replace("\n", "\r\n") + - """ - import java.util.Arrays; - import java.util.Set; - """, + """ + import java.util.Arrays; + import java.util.Set; + """, """ /* * Copyright 2023 the original author or authors. @@ -1305,31 +1346,31 @@ void crlfNewLinesInJavadoc() { """ import java.util.Arrays; import java.util.Set; - + """ + - """ - /** - * Copyright 2023 the original author or authors. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - *

- * https://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - """.replace("\n", "\r\n") + - "class Foo {}", + """ + /** + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + """.replace("\n", "\r\n") + + "class Foo {}", """ import java.util.Arrays; import java.util.List; import java.util.Set; - + /** * Copyright 2023 the original author or authors. *

diff --git a/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java b/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java index 3feff736d1c..9aecd66f89f 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/AddImport.java @@ -107,7 +107,7 @@ public AddImport(@Nullable String packageName, String typeName, @Nullable String // No need to add imports if the class to import is in java.lang, or if the classes are within the same package if (("java.lang".equals(packageName) && StringUtils.isBlank(member)) || (cu.getPackageDeclaration() != null && - packageName.equals(cu.getPackageDeclaration().getExpression().printTrimmed(getCursor())))) { + packageName.equals(cu.getPackageDeclaration().getExpression().printTrimmed(getCursor())))) { return cu; } @@ -119,10 +119,10 @@ public AddImport(@Nullable String packageName, String typeName, @Nullable String String ending = i.getQualid().getSimpleName(); if (member == null) { return !i.isStatic() && i.getPackageName().equals(packageName) && - (ending.equals(typeName) || "*".equals(ending)); + (ending.equals(typeName) || "*".equals(ending)); } return i.isStatic() && i.getTypeName().equals(fullyQualifiedName) && - (ending.equals(member) || "*".equals(ending)); + (ending.equals(member) || "*".equals(ending)); })) { return cu; } @@ -133,7 +133,7 @@ public AddImport(@Nullable String packageName, String typeName, @Nullable String new JLeftPadded<>(member == null ? Space.EMPTY : Space.SINGLE_SPACE, member != null, Markers.EMPTY), TypeTree.build(fullyQualifiedName + - (member == null ? "" : "." + member)).withPrefix(Space.SINGLE_SPACE), + (member == null ? "" : "." + member)).withPrefix(Space.SINGLE_SPACE), null); List> imports = new ArrayList<>(cu.getPadding().getImports()); @@ -214,7 +214,7 @@ private boolean hasReference(JavaSourceFile compilationUnit) { //Non-static imports, we just look for field accesses. for (NameTree t : FindTypes.find(compilationUnit, fullyQualifiedName)) { if ((!(t instanceof J.FieldAccess) || !((J.FieldAccess) t).isFullyQualifiedClassReference(fullyQualifiedName)) && - isTypeReference(t)) { + isTypeReference(t)) { return true; } } @@ -226,7 +226,7 @@ private boolean hasReference(JavaSourceFile compilationUnit) { if (invocation instanceof J.MethodInvocation) { J.MethodInvocation mi = (J.MethodInvocation) invocation; if (mi.getSelect() == null && - ("*".equals(member) || mi.getName().getSimpleName().equals(member))) { + ("*".equals(member) || mi.getName().getSimpleName().equals(member))) { return true; } } @@ -239,11 +239,18 @@ private boolean hasReference(JavaSourceFile compilationUnit) { } private class FindStaticFieldAccess extends JavaIsoVisitor> { + private boolean checkIsOfClassType(@Nullable JavaType type, String fullyQualifiedName) { + if (isOfClassType(type, fullyQualifiedName)) { + return true; + } + return type instanceof JavaType.Class && isOfClassType(((JavaType.Class) type).getOwningClass(), fullyQualifiedName); + } + @Override public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, AtomicReference found) { // If the type isn't used there's no need to proceed further for (JavaType.Variable varType : cu.getTypesInUse().getVariables()) { - if (varType.getName().equals(member) && isOfClassType(varType.getType(), fullyQualifiedName)) { + if (checkIsOfClassType(varType.getType(), fullyQualifiedName)) { return super.visitCompilationUnit(cu, found); } } @@ -253,8 +260,9 @@ public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, AtomicRefere @Override public J.Identifier visitIdentifier(J.Identifier identifier, AtomicReference found) { assert getCursor().getParent() != null; - if (identifier.getSimpleName().equals(member) && isOfClassType(identifier.getType(), fullyQualifiedName) && - !(getCursor().getParent().firstEnclosingOrThrow(J.class) instanceof J.FieldAccess)) { + if (identifier.getSimpleName().equals(member) && + checkIsOfClassType(identifier.getType(), fullyQualifiedName) && + !(getCursor().getParent().firstEnclosingOrThrow(J.class) instanceof J.FieldAccess)) { found.set(true); } return identifier;