Skip to content

Commit

Permalink
8338906: Avoid passing EnumDescs and extra classes to type switch met…
Browse files Browse the repository at this point in the history
…hods that don't use them

Reviewed-by: liach, jlahoda
  • Loading branch information
cl4es committed Aug 26, 2024
1 parent e63418e commit 3f00da8
Showing 1 changed file with 85 additions and 51 deletions.
136 changes: 85 additions & 51 deletions src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
import java.util.Optional;
import java.util.function.BiPredicate;
import java.util.function.Consumer;
import java.util.stream.Stream;

import jdk.internal.access.SharedSecrets;
import java.lang.classfile.ClassFile;
import java.lang.classfile.Label;
import java.lang.classfile.instruction.SwitchCase;

import jdk.internal.constant.ConstantUtils;
import jdk.internal.constant.MethodTypeDescImpl;
import jdk.internal.constant.ReferenceClassDescImpl;
import jdk.internal.misc.PreviewFeatures;
import jdk.internal.vm.annotation.Stable;
Expand Down Expand Up @@ -81,19 +82,27 @@ private SwitchBootstraps() {}
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
private static final boolean previewEnabled = PreviewFeatures.isEnabled();

private static final ClassDesc CD_BiPredicate = ReferenceClassDescImpl.ofValidated("Ljava/util/function/BiPredicate;");
private static final ClassDesc CD_Objects = ReferenceClassDescImpl.ofValidated("Ljava/util/Objects;");

private static final MethodType TYPES_SWITCH_TYPE = MethodType.methodType(int.class,
private static final MethodTypeDesc CHECK_INDEX_DESCRIPTOR =
MethodTypeDescImpl.ofValidated(ConstantDescs.CD_int, ConstantDescs.CD_int, ConstantDescs.CD_int);
private static final MethodTypeDesc MTD_TYPE_SWITCH = MethodTypeDescImpl.ofValidated(ConstantDescs.CD_int,
ConstantDescs.CD_Object,
ConstantDescs.CD_int);
private static final MethodTypeDesc MTD_TYPE_SWITCH_EXTRA = MethodTypeDescImpl.ofValidated(ConstantDescs.CD_int,
ConstantDescs.CD_Object,
ConstantDescs.CD_int,
CD_BiPredicate,
ConstantDescs.CD_List);
private static final MethodType MT_TYPE_SWITCH_EXTRA = MethodType.methodType(int.class,
Object.class,
int.class,
BiPredicate.class,
List.class);

private static final MethodTypeDesc TYPES_SWITCH_DESCRIPTOR =
MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiPredicate;Ljava/util/List;)I");
private static final MethodTypeDesc CHECK_INDEX_DESCRIPTOR =
MethodTypeDesc.ofDescriptor("(II)I");

private static final ClassDesc CD_Objects = ReferenceClassDescImpl.ofValidated("Ljava/util/Objects;");
private static final MethodType MT_TYPE_SWITCH = MethodType.methodType(int.class,
Object.class,
int.class);

private static class StaticHolders {
private static final MethodHandle MAPPED_ENUM_SWITCH;
Expand Down Expand Up @@ -180,7 +189,7 @@ public static CallSite typeSwitch(MethodHandles.Lookup lookup,
}

MethodHandle target = generateTypeSwitch(lookup, selectorType, labels);

target = target.asType(invocationType);
return new ConstantCallSite(target);
}

Expand Down Expand Up @@ -272,9 +281,8 @@ public static CallSite enumSwitch(MethodHandles.Lookup lookup,
|| !invocationType.parameterType(0).isEnum()
|| !invocationType.parameterType(1).equals(int.class))
throw new IllegalArgumentException("Illegal invocation type " + invocationType);
requireNonNull(labels);

labels = labels.clone();
labels = labels.clone(); // implicit null check

Class<?> enumClass = invocationType.parameterType(0);
boolean constantsOnly = true;
Expand All @@ -301,7 +309,6 @@ public static CallSite enumSwitch(MethodHandles.Lookup lookup,
} else {
target = generateTypeSwitch(lookup, invocationType.parameterType(0), labels);
}

target = target.asType(invocationType);

return new ConstantCallSite(target);
Expand Down Expand Up @@ -434,6 +441,33 @@ private static final class MappedEnumCache {
public MethodHandle generatedSwitch;
}

/**
* Check if the labelConstants can be converted statically to bytecode, or
* whether we'll need to compute and pass in extra information at the call site.
*/
private static boolean needsExtraInfo(Class<?> selectorType, Object[] labelConstants) {
for (int idx = labelConstants.length - 1; idx >= 0; idx--) {
Object currentLabel = labelConstants[idx];
if (currentLabel instanceof Class<?> classLabel) {
// No extra info needed for exact matches or primitives
if (unconditionalExactnessCheck(selectorType, classLabel) || classLabel.isPrimitive()) {
continue;
}
// Hidden classes - or arrays thereof - can't be nominally
// represented. Passed in as arguments.
while (classLabel.isArray()) {
classLabel = classLabel.getComponentType();
}
if (classLabel.isHidden()) {
return true;
}
} else if (currentLabel instanceof EnumDesc<?>) {
// EnumDescs labels needs late binding
return true;
}
}
return false;
}
/*
* Construct test chains for labels inside switch, to handle switch repeats:
* switch (idx) {
Expand Down Expand Up @@ -467,9 +501,10 @@ private static Consumer<CodeBuilder> generateTypeSwitchSkeleton(Class<?> selecto
}
cb.iload(RESTART_IDX);
Label dflt = cb.newLabel();
record Element(Label target, Label next, Object caseLabel) { }
List<Element> cases = new ArrayList<>();
List<SwitchCase> switchCases = new ArrayList<>();
Label[] caseTargets = new Label[labelConstants.length];
Label[] caseNext = new Label[labelConstants.length];
Object[] caseLabels = new Object[labelConstants.length];
SwitchCase[] switchCases = new SwitchCase[labelConstants.length];
Object lastLabel = null;
for (int idx = labelConstants.length - 1; idx >= 0; idx--) {
Object currentLabel = labelConstants[idx];
Expand All @@ -478,22 +513,22 @@ record Element(Label target, Label next, Object caseLabel) { }
if (lastLabel == null) {
next = dflt;
} else if (lastLabel.equals(currentLabel)) {
next = cases.getLast().next();
next = caseNext[idx + 1];
} else {
next = cases.getLast().target();
next = caseTargets[idx + 1];
}
lastLabel = currentLabel;
cases.add(new Element(target, next, currentLabel));
switchCases.add(SwitchCase.of(idx, target));
caseTargets[idx] = target;
caseNext[idx] = next;
caseLabels[idx] = currentLabel;
switchCases[idx] = SwitchCase.of(idx, target);
}
cases = cases.reversed();
switchCases = switchCases.reversed();
cb.tableswitch(0, labelConstants.length - 1, dflt, switchCases);
for (int idx = 0; idx < cases.size(); idx++) {
Element element = cases.get(idx);
Label next = element.next();
cb.labelBinding(element.target());
if (element.caseLabel() instanceof Class<?> classLabel) {
cb.tableswitch(0, labelConstants.length - 1, dflt, Arrays.asList(switchCases));
for (int idx = 0; idx < labelConstants.length; idx++) {
Label next = caseNext[idx];
Object caseLabel = caseLabels[idx];
cb.labelBinding(caseTargets[idx]);
if (caseLabel instanceof Class<?> classLabel) {
if (unconditionalExactnessCheck(selectorType, classLabel)) {
//nothing - unconditionally use this case
} else if (classLabel.isPrimitive()) {
Expand Down Expand Up @@ -577,7 +612,7 @@ record Element(Label target, Label next, Object caseLabel) { }
extraClassLabels.add(classLabel);
}
}
} else if (element.caseLabel() instanceof EnumDesc<?> enumLabel) {
} else if (caseLabel instanceof EnumDesc<?> enumLabel) {
int enumIdx = enumDescs.size();
enumDescs.add(enumLabel);
cb.aload(ENUM_CACHE);
Expand All @@ -587,21 +622,21 @@ record Element(Label target, Label next, Object caseLabel) { }
MethodTypeDesc.of(ConstantDescs.CD_Integer,
ConstantDescs.CD_int));
cb.aload(SELECTOR_OBJ);
cb.invokeinterface(referenceClassDesc(BiPredicate.class),
cb.invokeinterface(CD_BiPredicate,
"test",
MethodTypeDesc.of(ConstantDescs.CD_boolean,
ConstantDescs.CD_Object,
ConstantDescs.CD_Object));
cb.ifeq(next);
} else if (element.caseLabel() instanceof String stringLabel) {
} else if (caseLabel instanceof String stringLabel) {
cb.ldc(stringLabel);
cb.aload(SELECTOR_OBJ);
cb.invokevirtual(ConstantDescs.CD_Object,
"equals",
MethodTypeDesc.of(ConstantDescs.CD_boolean,
ConstantDescs.CD_Object));
cb.ifeq(next);
} else if (element.caseLabel() instanceof Integer integerLabel) {
} else if (caseLabel instanceof Integer integerLabel) {
Label compare = cb.newLabel();
Label notNumber = cb.newLabel();
cb.aload(SELECTOR_OBJ);
Expand All @@ -626,16 +661,16 @@ record Element(Label target, Label next, Object caseLabel) { }

cb.ldc(integerLabel);
cb.if_icmpne(next);
} else if ((element.caseLabel() instanceof Long ||
element.caseLabel() instanceof Float ||
element.caseLabel() instanceof Double ||
element.caseLabel() instanceof Boolean)) {
if (element.caseLabel() instanceof Boolean c) {
} else if ((caseLabel instanceof Long ||
caseLabel instanceof Float ||
caseLabel instanceof Double ||
caseLabel instanceof Boolean)) {
if (caseLabel instanceof Boolean c) {
cb.loadConstant(c ? 1 : 0);
} else {
cb.loadConstant((ConstantDesc) element.caseLabel());
cb.loadConstant((ConstantDesc) caseLabel);
}
var caseLabelWrapper = Wrapper.forWrapperType(element.caseLabel().getClass());
var caseLabelWrapper = Wrapper.forWrapperType(caseLabel.getClass());
cb.invokestatic(caseLabelWrapper.wrapperClassDescriptor(),
"valueOf",
MethodTypeDesc.of(caseLabelWrapper.wrapperClassDescriptor(),
Expand All @@ -648,13 +683,13 @@ record Element(Label target, Label next, Object caseLabel) { }
cb.ifeq(next);
} else {
throw new InternalError("Unsupported label type: " +
element.caseLabel().getClass());
caseLabel.getClass());
}
cb.loadConstant(idx);
cb.ireturn();
}
cb.labelBinding(dflt);
cb.loadConstant(cases.size());
cb.loadConstant(labelConstants.length);
cb.ireturn();
};
}
Expand All @@ -663,14 +698,15 @@ record Element(Label target, Label next, Object caseLabel) { }
* Construct the method handle that represents the method int typeSwitch(Object, int, BiPredicate, List)
*/
private static MethodHandle generateTypeSwitch(MethodHandles.Lookup caller, Class<?> selectorType, Object[] labelConstants) {
List<EnumDesc<?>> enumDescs = new ArrayList<>();
List<Class<?>> extraClassLabels = new ArrayList<>();
boolean addExtraInfo = needsExtraInfo(selectorType, labelConstants);
List<EnumDesc<?>> enumDescs = addExtraInfo ? new ArrayList<>() : null;
List<Class<?>> extraClassLabels = addExtraInfo ? new ArrayList<>() : null;

byte[] classBytes = ClassFile.of().build(ConstantUtils.binaryNameToDesc(typeSwitchClassName(caller.lookupClass())),
clb -> {
clb.withFlags(AccessFlag.FINAL, AccessFlag.SUPER, AccessFlag.SYNTHETIC)
.withMethodBody("typeSwitch",
TYPES_SWITCH_DESCRIPTOR,
addExtraInfo ? MTD_TYPE_SWITCH_EXTRA : MTD_TYPE_SWITCH,
ClassFile.ACC_FINAL | ClassFile.ACC_PUBLIC | ClassFile.ACC_STATIC,
generateTypeSwitchSkeleton(selectorType, labelConstants, enumDescs, extraClassLabels));
});
Expand All @@ -681,13 +717,11 @@ private static MethodHandle generateTypeSwitch(MethodHandles.Lookup caller, Clas
lookup = caller.defineHiddenClass(classBytes, true, NESTMATE, STRONG);
MethodHandle typeSwitch = lookup.findStatic(lookup.lookupClass(),
"typeSwitch",
TYPES_SWITCH_TYPE);
typeSwitch = MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(new EnumDesc<?>[0])),
List.copyOf(extraClassLabels));
typeSwitch = MethodHandles.explicitCastArguments(typeSwitch,
MethodType.methodType(int.class,
selectorType,
int.class));
addExtraInfo ? MT_TYPE_SWITCH_EXTRA : MT_TYPE_SWITCH);
if (addExtraInfo) {
typeSwitch = MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(new EnumDesc<?>[0])),
List.copyOf(extraClassLabels));
}
return typeSwitch;
} catch (Throwable t) {
throw new IllegalArgumentException(t);
Expand Down

0 comments on commit 3f00da8

Please sign in to comment.