From 3f00da84b3e6fb001e7d56acb198292b28d40c8b Mon Sep 17 00:00:00 2001 From: Claes Redestad Date: Mon, 26 Aug 2024 15:58:25 +0000 Subject: [PATCH] 8338906: Avoid passing EnumDescs and extra classes to type switch methods that don't use them Reviewed-by: liach, jlahoda --- .../java/lang/runtime/SwitchBootstraps.java | 136 +++++++++++------- 1 file changed, 85 insertions(+), 51 deletions(-) diff --git a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java index 43f7339c75a96..bfdb76e2ef1b6 100644 --- a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java +++ b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java @@ -43,13 +43,14 @@ import java.util.Optional; import java.util.function.BiPredicate; import java.util.function.Consumer; -import java.util.stream.Stream; + import jdk.internal.access.SharedSecrets; import java.lang.classfile.ClassFile; import java.lang.classfile.Label; import java.lang.classfile.instruction.SwitchCase; import jdk.internal.constant.ConstantUtils; +import jdk.internal.constant.MethodTypeDescImpl; import jdk.internal.constant.ReferenceClassDescImpl; import jdk.internal.misc.PreviewFeatures; import jdk.internal.vm.annotation.Stable; @@ -81,19 +82,27 @@ private SwitchBootstraps() {} private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); private static final boolean previewEnabled = PreviewFeatures.isEnabled(); + private static final ClassDesc CD_BiPredicate = ReferenceClassDescImpl.ofValidated("Ljava/util/function/BiPredicate;"); + private static final ClassDesc CD_Objects = ReferenceClassDescImpl.ofValidated("Ljava/util/Objects;"); - private static final MethodType TYPES_SWITCH_TYPE = MethodType.methodType(int.class, + private static final MethodTypeDesc CHECK_INDEX_DESCRIPTOR = + MethodTypeDescImpl.ofValidated(ConstantDescs.CD_int, ConstantDescs.CD_int, ConstantDescs.CD_int); + private static final MethodTypeDesc MTD_TYPE_SWITCH = MethodTypeDescImpl.ofValidated(ConstantDescs.CD_int, + ConstantDescs.CD_Object, + ConstantDescs.CD_int); + private static final MethodTypeDesc MTD_TYPE_SWITCH_EXTRA = MethodTypeDescImpl.ofValidated(ConstantDescs.CD_int, + ConstantDescs.CD_Object, + ConstantDescs.CD_int, + CD_BiPredicate, + ConstantDescs.CD_List); + private static final MethodType MT_TYPE_SWITCH_EXTRA = MethodType.methodType(int.class, Object.class, int.class, BiPredicate.class, List.class); - - private static final MethodTypeDesc TYPES_SWITCH_DESCRIPTOR = - MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiPredicate;Ljava/util/List;)I"); - private static final MethodTypeDesc CHECK_INDEX_DESCRIPTOR = - MethodTypeDesc.ofDescriptor("(II)I"); - - private static final ClassDesc CD_Objects = ReferenceClassDescImpl.ofValidated("Ljava/util/Objects;"); + private static final MethodType MT_TYPE_SWITCH = MethodType.methodType(int.class, + Object.class, + int.class); private static class StaticHolders { private static final MethodHandle MAPPED_ENUM_SWITCH; @@ -180,7 +189,7 @@ public static CallSite typeSwitch(MethodHandles.Lookup lookup, } MethodHandle target = generateTypeSwitch(lookup, selectorType, labels); - + target = target.asType(invocationType); return new ConstantCallSite(target); } @@ -272,9 +281,8 @@ public static CallSite enumSwitch(MethodHandles.Lookup lookup, || !invocationType.parameterType(0).isEnum() || !invocationType.parameterType(1).equals(int.class)) throw new IllegalArgumentException("Illegal invocation type " + invocationType); - requireNonNull(labels); - labels = labels.clone(); + labels = labels.clone(); // implicit null check Class enumClass = invocationType.parameterType(0); boolean constantsOnly = true; @@ -301,7 +309,6 @@ public static CallSite enumSwitch(MethodHandles.Lookup lookup, } else { target = generateTypeSwitch(lookup, invocationType.parameterType(0), labels); } - target = target.asType(invocationType); return new ConstantCallSite(target); @@ -434,6 +441,33 @@ private static final class MappedEnumCache { public MethodHandle generatedSwitch; } + /** + * Check if the labelConstants can be converted statically to bytecode, or + * whether we'll need to compute and pass in extra information at the call site. + */ + private static boolean needsExtraInfo(Class selectorType, Object[] labelConstants) { + for (int idx = labelConstants.length - 1; idx >= 0; idx--) { + Object currentLabel = labelConstants[idx]; + if (currentLabel instanceof Class classLabel) { + // No extra info needed for exact matches or primitives + if (unconditionalExactnessCheck(selectorType, classLabel) || classLabel.isPrimitive()) { + continue; + } + // Hidden classes - or arrays thereof - can't be nominally + // represented. Passed in as arguments. + while (classLabel.isArray()) { + classLabel = classLabel.getComponentType(); + } + if (classLabel.isHidden()) { + return true; + } + } else if (currentLabel instanceof EnumDesc) { + // EnumDescs labels needs late binding + return true; + } + } + return false; + } /* * Construct test chains for labels inside switch, to handle switch repeats: * switch (idx) { @@ -467,9 +501,10 @@ private static Consumer generateTypeSwitchSkeleton(Class selecto } cb.iload(RESTART_IDX); Label dflt = cb.newLabel(); - record Element(Label target, Label next, Object caseLabel) { } - List cases = new ArrayList<>(); - List switchCases = new ArrayList<>(); + Label[] caseTargets = new Label[labelConstants.length]; + Label[] caseNext = new Label[labelConstants.length]; + Object[] caseLabels = new Object[labelConstants.length]; + SwitchCase[] switchCases = new SwitchCase[labelConstants.length]; Object lastLabel = null; for (int idx = labelConstants.length - 1; idx >= 0; idx--) { Object currentLabel = labelConstants[idx]; @@ -478,22 +513,22 @@ record Element(Label target, Label next, Object caseLabel) { } if (lastLabel == null) { next = dflt; } else if (lastLabel.equals(currentLabel)) { - next = cases.getLast().next(); + next = caseNext[idx + 1]; } else { - next = cases.getLast().target(); + next = caseTargets[idx + 1]; } lastLabel = currentLabel; - cases.add(new Element(target, next, currentLabel)); - switchCases.add(SwitchCase.of(idx, target)); + caseTargets[idx] = target; + caseNext[idx] = next; + caseLabels[idx] = currentLabel; + switchCases[idx] = SwitchCase.of(idx, target); } - cases = cases.reversed(); - switchCases = switchCases.reversed(); - cb.tableswitch(0, labelConstants.length - 1, dflt, switchCases); - for (int idx = 0; idx < cases.size(); idx++) { - Element element = cases.get(idx); - Label next = element.next(); - cb.labelBinding(element.target()); - if (element.caseLabel() instanceof Class classLabel) { + cb.tableswitch(0, labelConstants.length - 1, dflt, Arrays.asList(switchCases)); + for (int idx = 0; idx < labelConstants.length; idx++) { + Label next = caseNext[idx]; + Object caseLabel = caseLabels[idx]; + cb.labelBinding(caseTargets[idx]); + if (caseLabel instanceof Class classLabel) { if (unconditionalExactnessCheck(selectorType, classLabel)) { //nothing - unconditionally use this case } else if (classLabel.isPrimitive()) { @@ -577,7 +612,7 @@ record Element(Label target, Label next, Object caseLabel) { } extraClassLabels.add(classLabel); } } - } else if (element.caseLabel() instanceof EnumDesc enumLabel) { + } else if (caseLabel instanceof EnumDesc enumLabel) { int enumIdx = enumDescs.size(); enumDescs.add(enumLabel); cb.aload(ENUM_CACHE); @@ -587,13 +622,13 @@ record Element(Label target, Label next, Object caseLabel) { } MethodTypeDesc.of(ConstantDescs.CD_Integer, ConstantDescs.CD_int)); cb.aload(SELECTOR_OBJ); - cb.invokeinterface(referenceClassDesc(BiPredicate.class), + cb.invokeinterface(CD_BiPredicate, "test", MethodTypeDesc.of(ConstantDescs.CD_boolean, ConstantDescs.CD_Object, ConstantDescs.CD_Object)); cb.ifeq(next); - } else if (element.caseLabel() instanceof String stringLabel) { + } else if (caseLabel instanceof String stringLabel) { cb.ldc(stringLabel); cb.aload(SELECTOR_OBJ); cb.invokevirtual(ConstantDescs.CD_Object, @@ -601,7 +636,7 @@ record Element(Label target, Label next, Object caseLabel) { } MethodTypeDesc.of(ConstantDescs.CD_boolean, ConstantDescs.CD_Object)); cb.ifeq(next); - } else if (element.caseLabel() instanceof Integer integerLabel) { + } else if (caseLabel instanceof Integer integerLabel) { Label compare = cb.newLabel(); Label notNumber = cb.newLabel(); cb.aload(SELECTOR_OBJ); @@ -626,16 +661,16 @@ record Element(Label target, Label next, Object caseLabel) { } cb.ldc(integerLabel); cb.if_icmpne(next); - } else if ((element.caseLabel() instanceof Long || - element.caseLabel() instanceof Float || - element.caseLabel() instanceof Double || - element.caseLabel() instanceof Boolean)) { - if (element.caseLabel() instanceof Boolean c) { + } else if ((caseLabel instanceof Long || + caseLabel instanceof Float || + caseLabel instanceof Double || + caseLabel instanceof Boolean)) { + if (caseLabel instanceof Boolean c) { cb.loadConstant(c ? 1 : 0); } else { - cb.loadConstant((ConstantDesc) element.caseLabel()); + cb.loadConstant((ConstantDesc) caseLabel); } - var caseLabelWrapper = Wrapper.forWrapperType(element.caseLabel().getClass()); + var caseLabelWrapper = Wrapper.forWrapperType(caseLabel.getClass()); cb.invokestatic(caseLabelWrapper.wrapperClassDescriptor(), "valueOf", MethodTypeDesc.of(caseLabelWrapper.wrapperClassDescriptor(), @@ -648,13 +683,13 @@ record Element(Label target, Label next, Object caseLabel) { } cb.ifeq(next); } else { throw new InternalError("Unsupported label type: " + - element.caseLabel().getClass()); + caseLabel.getClass()); } cb.loadConstant(idx); cb.ireturn(); } cb.labelBinding(dflt); - cb.loadConstant(cases.size()); + cb.loadConstant(labelConstants.length); cb.ireturn(); }; } @@ -663,14 +698,15 @@ record Element(Label target, Label next, Object caseLabel) { } * Construct the method handle that represents the method int typeSwitch(Object, int, BiPredicate, List) */ private static MethodHandle generateTypeSwitch(MethodHandles.Lookup caller, Class selectorType, Object[] labelConstants) { - List> enumDescs = new ArrayList<>(); - List> extraClassLabels = new ArrayList<>(); + boolean addExtraInfo = needsExtraInfo(selectorType, labelConstants); + List> enumDescs = addExtraInfo ? new ArrayList<>() : null; + List> extraClassLabels = addExtraInfo ? new ArrayList<>() : null; byte[] classBytes = ClassFile.of().build(ConstantUtils.binaryNameToDesc(typeSwitchClassName(caller.lookupClass())), clb -> { clb.withFlags(AccessFlag.FINAL, AccessFlag.SUPER, AccessFlag.SYNTHETIC) .withMethodBody("typeSwitch", - TYPES_SWITCH_DESCRIPTOR, + addExtraInfo ? MTD_TYPE_SWITCH_EXTRA : MTD_TYPE_SWITCH, ClassFile.ACC_FINAL | ClassFile.ACC_PUBLIC | ClassFile.ACC_STATIC, generateTypeSwitchSkeleton(selectorType, labelConstants, enumDescs, extraClassLabels)); }); @@ -681,13 +717,11 @@ private static MethodHandle generateTypeSwitch(MethodHandles.Lookup caller, Clas lookup = caller.defineHiddenClass(classBytes, true, NESTMATE, STRONG); MethodHandle typeSwitch = lookup.findStatic(lookup.lookupClass(), "typeSwitch", - TYPES_SWITCH_TYPE); - typeSwitch = MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(new EnumDesc[0])), - List.copyOf(extraClassLabels)); - typeSwitch = MethodHandles.explicitCastArguments(typeSwitch, - MethodType.methodType(int.class, - selectorType, - int.class)); + addExtraInfo ? MT_TYPE_SWITCH_EXTRA : MT_TYPE_SWITCH); + if (addExtraInfo) { + typeSwitch = MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(new EnumDesc[0])), + List.copyOf(extraClassLabels)); + } return typeSwitch; } catch (Throwable t) { throw new IllegalArgumentException(t);