Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support intersection type casts #3652

Merged
merged 8 commits into from
Nov 21, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
typeMapping.type(node));
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -59,7 +60,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -78,6 +81,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -39,6 +40,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
type);
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -56,7 +57,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -75,6 +78,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -39,6 +40,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
type);
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -56,7 +57,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -75,6 +78,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -39,6 +40,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,14 @@ public J visitInstanceOf(InstanceOfTree node, Space fmt) {
typeMapping.type(node));
}

@Override
public J visitIntersectionType(IntersectionTypeTree node, Space fmt) {
JContainer<TypeTree> bounds = node.getBounds().isEmpty() ? null :
JContainer.build(EMPTY,
convertAll(node.getBounds(), t -> sourceBefore("&"), noDelim), Markers.EMPTY);
return new J.IntersectionType(randomId(), fmt, Markers.EMPTY, bounds);
}

@Override
public J visitLabeledStatement(LabeledStatementTree node, Space fmt) {
skip(node.getLabel().toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.openrewrite.java.tree.TypeUtils;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.ArrayList;
Expand Down Expand Up @@ -58,7 +59,9 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
return existing;
}

if (type instanceof Type.ClassType) {
if (type instanceof Type.IntersectionClassType) {
return intersectionType((Type.IntersectionClassType) type, signature);
} else if (type instanceof Type.ClassType) {
return classType((Type.ClassType) type, signature);
} else if (type instanceof Type.TypeVar) {
return generic((Type.TypeVar) type, signature);
Expand All @@ -79,6 +82,19 @@ public JavaType type(@Nullable com.sun.tools.javac.code.Type type) {
throw new UnsupportedOperationException("Unknown type " + type.getClass().getName());
}

private JavaType intersectionType(Type.IntersectionClassType type, String signature) {
JavaType.Intersection intersection = new JavaType.Intersection(null);
typeCache.put(signature, intersection);
JavaType[] types = new JavaType[type.getBounds().size()];
List<? extends TypeMirror> bounds = type.getBounds();
for (int i = 0; i < bounds.size(); i++) {
TypeMirror bound = bounds.get(i);
types[i] = type((Type) bound);
}
intersection.unsafeSet(types);
return intersection;
}

private JavaType array(Type type, String signature) {
JavaType.Array arr = new JavaType.Array(null, null);
typeCache.put(signature, arr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.openrewrite.java.tree.JavaType;

import javax.lang.model.type.NullType;
import javax.lang.model.type.TypeMirror;
import java.util.HashSet;
import java.util.Set;
import java.util.StringJoiner;
Expand All @@ -38,6 +39,13 @@ public String signature(@Nullable Object t) {
private String signature(@Nullable Type type) {
if (type == null || type instanceof Type.UnknownType || type instanceof NullType) {
return "{undefined}";
} else if (type instanceof Type.IntersectionClassType) {
Type.IntersectionClassType intersectionClassType = (Type.IntersectionClassType) type;
StringJoiner joiner = new StringJoiner(" & ");
for (TypeMirror typeArg : intersectionClassType.getBounds()) {
joiner.add(signature(typeArg));
}
return joiner.toString();
} else if (type instanceof Type.ClassType) {
try {
return ((Type.ClassType) type).typarams_field != null && ((Type.ClassType) type).typarams_field.length() > 0 ? parameterizedSignature(type) : classSignature(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
package org.openrewrite.java.tree;

import org.junit.jupiter.api.Test;
import org.openrewrite.java.MinimumJava11;
import org.openrewrite.test.RewriteTest;

import static org.assertj.core.api.Assertions.assertThat;
import static org.openrewrite.java.Assertions.java;

class TypeCastTest implements RewriteTest {
Expand All @@ -35,4 +37,59 @@ class Test {
)
);
}

@Test
void intersectionCast() {
rewriteRun(
java(
"""
import java.io.Serializable;
import java.util.function.BiFunction;

class Test {
Serializable s = (Serializable & BiFunction<Integer, Integer, Integer>) Integer::sum;
}
"""
)
);
}

@MinimumJava11
@Test
void intersectionCastAssignedToVar() {
rewriteRun(
java(
"""
import java.io.Serializable;
import java.util.function.BiFunction;

class Test {
void m() {
var s = (Serializable & BiFunction<Integer, Integer, Integer>) Integer::sum;
}
}
""",
spec -> spec.afterRecipe(cu -> {
J.MethodDeclaration m = (J.MethodDeclaration) cu.getClasses().get(0).getBody().getStatements().get(0);
J.VariableDeclarations s = (J.VariableDeclarations) m.getBody().getStatements().get(0);
assertThat(s.getType()).isInstanceOf(JavaType.Intersection.class);
JavaType.Intersection intersection = (JavaType.Intersection) s.getType();
assertThat(intersection.getBounds()).satisfiesExactly(
b1 -> assertThat(b1).satisfies(
t -> assertThat(t).isInstanceOf(JavaType.Class.class),
t -> assertThat(((JavaType.Class) t).getFullyQualifiedName()).isEqualTo("java.io.Serializable")
),
b2 -> assertThat(b2).satisfies(
t -> assertThat(t).isInstanceOf(JavaType.Parameterized.class),
t -> assertThat(((JavaType.Parameterized) t).getFullyQualifiedName()).isEqualTo("java.util.function.BiFunction"),
t -> assertThat(((JavaType.Parameterized) t).getTypeParameters()).hasSize(3),
t -> assertThat(((JavaType.Parameterized) t).getTypeParameters()).allSatisfy(
p -> assertThat(((JavaType.Class) p).getFullyQualifiedName()).isEqualTo("java.lang.Integer")
)
)
);
})
)
);
}
}
Loading