Skip to content

Commit

Permalink
SemanticallyEqual: Fix bug in visitFieldAccess() (#3827)
Browse files Browse the repository at this point in the history
* `SemanticallyEqual`: Fix bug in `visitFieldAccess()`

When comparing `J.FieldAccess` instances, the targets also needs to be compared.

* Start adding some more `SemanticallyEqualTest` test cases

* Polish

* Polish

* Add more tests

* Add more tests
  • Loading branch information
knutwannheden authored Dec 18, 2023
1 parent 8e5c89c commit 3659f55
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,28 @@ public class A {
);
}

@Test
void differentFieldAccesses() {
rewriteRun(
java(
"""
public class A {
Object f = null;
class B extends A {
boolean m(Object o) {
B other = (B) o;
if (this.f == null || other.f == null) {
return true;
}
return false;
}
}
}
"""
)
);
}

@Test
void preserveComments() {
rewriteRun(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@

import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.tree.J;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class SemanticallyEqualTest {

private final JavaParser javaParser = JavaParser.fromJavaVersion().build();

@Test
void compareAbstractMethods() {
void abstractMethods() {
assertEqualToSelf(
"""
abstract class A {
Expand All @@ -38,7 +45,7 @@ abstract class A {
}

@Test
void compareClassModifierLists() {
void classModifierLists() {
assertEqual(
"""
public abstract class A {
Expand All @@ -52,7 +59,7 @@ abstract public class A {
}

@Test
void compareLiterals() {
void literals() {
assertEqual(
"""
class A {
Expand All @@ -67,6 +74,141 @@ class A {
);
}

@Test
void classLiterals() {
assertExpressionsEqual(
"""
import java.util.UUID;
class T {
Class<?> a = java.util.UUID.class;
Class<?> b = UUID.class;
}
"""
);
assertExpressionsNotEqual(
"""
import java.util.UUID;
class T {
Class<UUID> u = UUID.class;
Class<UUID> a = UUID.class;
Class<UUID> b = this.u;
}
"""
);
}

@CartesianTest
void staticFieldAccesses(@CartesianTest.Values(strings = {
"java.util.regex.Pattern.CASE_INSENSITIVE",
"Pattern.CASE_INSENSITIVE",
"CASE_INSENSITIVE",
}) String a, @CartesianTest.Values(strings = {
"java.util.regex.Pattern.CASE_INSENSITIVE",
"Pattern.CASE_INSENSITIVE",
"CASE_INSENSITIVE",
}) String b) {
assertExpressionsEqual(
"import java.util.regex.Pattern; import static java.util.regex.Pattern.CASE_INSENSITIVE; class T { int a = " + a + "; int b = " + b + "; }"
);
}

@CartesianTest
void staticMethodAccesses(@CartesianTest.Values(strings = {
"java.util.regex.Pattern.compile",
"Pattern.compile",
"compile",
}) String a, @CartesianTest.Values(strings = {
"java.util.regex.Pattern.compile",
"Pattern.compile",
"compile",
}) String b) {
assertExpressionsEqual(
"import java.util.regex.Pattern; import static java.util.regex.Pattern.compile; class T { Pattern a = " + a + "(\"\"); Pattern b = " + b + "(\"\"); }"
);
}

@Test
void typeCasts() {
assertExpressionsEqual(
"""
class T {
Number a = (java.lang.Number) "";
Number b = (java.lang.Number) "";
}
"""
);
assertExpressionsEqual(
"""
class T {
Number a = (java.lang.Number) "";
Number b = (Number) "";
}
"""
);
assertExpressionsEqual(
"""
import java.util.List;
import java.util.UUID;
class T {
Number a = (List<UUID>) "";
Number b = (List<java.util.UUID>) "";
}
"""
);
assertExpressionsEqual(
"""
import java.util.List;
import java.util.UUID;
class T {
Number a = (List<java.util.UUID>) "";
Number b = (java.util.List<UUID>) "";
}
"""
);
}

@Test
void fieldAccesses() {
assertExpressionsEqual(
"""
class T {
int n = 1;
int a = T.this.n;
int b = T.this.n;
}
"""
);
assertExpressionsEqual(
"""
class T {
int n = 1;
int a = T.this.n;
int b = this.n;
}
"""
);
assertExpressionsEqual(
"""
class T {
int n = 1;
int a = T.this.n;
int b = n;
}
"""
);
assertExpressionsNotEqual(
"""
class T {
int n = 1;
void m(int n) {
int a = T.this.n;
int b = n;
}
}
"""
);
}

private void assertEqualToSelf(@Language("java") String a) {
assertEqual(a, a);
}
Expand All @@ -76,9 +218,50 @@ private void assertEqual(@Language("java") String a, @Language("java") String b)
J.CompilationUnit cua = (J.CompilationUnit) javaParser.parse(a).findFirst().get();
javaParser.reset();
J.CompilationUnit cub = (J.CompilationUnit) javaParser.parse(b).findFirst().get();
javaParser.reset();
assertEqual(cua, cub);
}

@SuppressWarnings("OptionalGetWithoutIsPresent")
private void assertExpressionsEqual(@Language(value = "java") String source) {
J.CompilationUnit cu = (J.CompilationUnit) javaParser.parse(source).findFirst().get();
javaParser.reset();

JavaIsoVisitor<Map<String, J.VariableDeclarations.NamedVariable>> visitor = new JavaIsoVisitor<>() {
@Override
public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, Map<String, J.VariableDeclarations.NamedVariable> map) {
map.put(variable.getSimpleName(), variable);
return super.visitVariable(variable, map);
}
};

Map<String, J.VariableDeclarations.NamedVariable> result = visitor.reduce(cu, new HashMap<>());
assertThat(SemanticallyEqual.areEqual(
Objects.requireNonNull(result.get("a").getInitializer()),
Objects.requireNonNull(result.get("b").getInitializer()))
).isTrue();
}

@SuppressWarnings("OptionalGetWithoutIsPresent")
private void assertExpressionsNotEqual(@Language(value = "java") String source) {
J.CompilationUnit cu = (J.CompilationUnit) javaParser.parse(source).findFirst().get();
javaParser.reset();

JavaIsoVisitor<Map<String, J.VariableDeclarations.NamedVariable>> visitor = new JavaIsoVisitor<>() {
@Override
public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, Map<String, J.VariableDeclarations.NamedVariable> map) {
map.put(variable.getSimpleName(), variable);
return super.visitVariable(variable, map);
}
};

Map<String, J.VariableDeclarations.NamedVariable> result = visitor.reduce(cu, new HashMap<>());
assertThat(SemanticallyEqual.areEqual(
Objects.requireNonNull(result.get("a").getInitializer()),
Objects.requireNonNull(result.get("b").getInitializer()))
).isFalse();
}

private void assertEqual(J a, J b) {
assertTrue(SemanticallyEqual.areEqual(a, b));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,18 +575,27 @@ public J.EnumValueSet visitEnumValueSet(J.EnumValueSet enums, J j) {
public J.FieldAccess visitFieldAccess(J.FieldAccess fieldAccess, J j) {
if (isEqual.get()) {
if (!(j instanceof J.FieldAccess)) {
if (!(j instanceof J.Identifier) || !TypeUtils.isOfType(fieldAccess.getName().getFieldType(), ((J.Identifier) j).getFieldType())) {
if (!(j instanceof J.Identifier) ||
!TypeUtils.isOfType(fieldAccess.getName().getFieldType(), ((J.Identifier) j).getFieldType()) ||
!fieldAccess.getSimpleName().equals(((J.Identifier) j).getSimpleName())) {
isEqual.set(false);
}
return fieldAccess;
}

J.FieldAccess compareTo = (J.FieldAccess) j;
if (!TypeUtils.isOfType(fieldAccess.getType(), compareTo.getType())
|| !TypeUtils.isOfType(fieldAccess.getName().getFieldType(), compareTo.getName().getFieldType())) {
if (fieldAccess.getType() instanceof JavaType.Unknown && compareTo.getType() instanceof JavaType.Unknown) {
if (!fieldAccess.getSimpleName().equals(compareTo.getSimpleName())) {
isEqual.set(false);
return fieldAccess;
}
} else if (!TypeUtils.isOfType(fieldAccess.getType(), compareTo.getType())
|| !TypeUtils.isOfType(fieldAccess.getName().getFieldType(), compareTo.getName().getFieldType())) {
isEqual.set(false);
return fieldAccess;
}

visit(fieldAccess.getTarget(), compareTo.getTarget());
}
return fieldAccess;
}
Expand Down

0 comments on commit 3659f55

Please sign in to comment.