From 2d17376e22f46a5aa6e7295e4830de68bba86d3b Mon Sep 17 00:00:00 2001 From: andrii0lomakin Date: Wed, 14 Aug 2024 08:45:19 +0200 Subject: [PATCH] Fix of issue #521 in SpirV. --- .../phases/TornadoTaskSpecialization.java | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoTaskSpecialization.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoTaskSpecialization.java index a9ca24a38a..478eb04e08 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoTaskSpecialization.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoTaskSpecialization.java @@ -39,6 +39,7 @@ import org.graalvm.compiler.graph.Node; import org.graalvm.compiler.graph.iterators.NodeIterable; import org.graalvm.compiler.nodes.ConstantNode; +import org.graalvm.compiler.nodes.FixedGuardNode; import org.graalvm.compiler.nodes.GraphState; import org.graalvm.compiler.nodes.LogicConstantNode; import org.graalvm.compiler.nodes.NodeView; @@ -48,6 +49,7 @@ import org.graalvm.compiler.nodes.StructuredGraph; import org.graalvm.compiler.nodes.calc.IntegerLessThanNode; import org.graalvm.compiler.nodes.calc.IsNullNode; +import org.graalvm.compiler.nodes.extended.UnboxNode; import org.graalvm.compiler.nodes.java.ArrayLengthNode; import org.graalvm.compiler.nodes.java.LoadFieldNode; import org.graalvm.compiler.nodes.util.GraphUtil; @@ -251,20 +253,22 @@ private void evaluate(final StructuredGraph graph, final Node node, final Object } } - private ConstantNode createConstantFromObject(Object obj, StructuredGraph graph) { - ConstantNode result = null; - switch (obj) { - case Byte objByte -> result = ConstantNode.forByte(objByte, graph); - case Character objChar -> result = ConstantNode.forChar(objChar, graph); - case Short objShort -> result = ConstantNode.forShort(objShort, graph); - case HalfFloat objHalfFloat -> result = ConstantNode.forFloat(objHalfFloat.getFloat32(), graph); - case Integer objInteger -> result = ConstantNode.forInt(objInteger, graph); - case Float objFloat -> result = ConstantNode.forFloat(objFloat, graph); - case Double objDouble -> result = ConstantNode.forDouble(objDouble, graph); - case Long objLong -> result = ConstantNode.forLong(objLong, graph); - case null, default -> unimplemented("createConstantFromObject: %s", obj); - } - return result; + private ConstantNode createPrimitiveConstantFromObjectParameter(Object obj, StructuredGraph graph) { + return switch (obj) { + case Boolean objBoolean -> ConstantNode.forBoolean(objBoolean, graph); + case Byte objByte -> ConstantNode.forByte(objByte, graph); + case Character objChar -> ConstantNode.forChar(objChar, graph); + case Short objShort -> ConstantNode.forShort(objShort, graph); + case HalfFloat objHalfFloat -> ConstantNode.forFloat(objHalfFloat.getFloat32(), graph); + case Integer objInteger -> ConstantNode.forInt(objInteger, graph); + case Float objFloat -> ConstantNode.forFloat(objFloat, graph); + case Double objDouble -> ConstantNode.forDouble(objDouble, graph); + case Long objLong -> ConstantNode.forLong(objLong, graph); + case Object object -> { + unimplemented("createPrimitiveConstantFromObjectParameter: %s", obj); + yield null; + } + }; } private boolean isParameterInvolvedInParallelLoopBound(Node parameterNode) { @@ -294,8 +298,34 @@ private void propagateParameters(StructuredGraph graph, ParameterNode parameterN parameterNode.replaceAtUsages(kernelContextAccessNode); index++; } else { - ConstantNode constant = createConstantFromObject(args[parameterNode.index()], graph); - parameterNode.replaceAtUsages(constant); + var value = args[parameterNode.index()]; + ConstantNode primitiveConstant = createPrimitiveConstantFromObjectParameter(value, graph); + + parameterNode.replaceAtAllUsages(primitiveConstant, true); + parameterNode.safeDelete(); + + //remove Unbox nodes, they are not needed for constant values and cause compilation errors + graph.getNodes().filter(n -> n instanceof PiNode piNode && piNode.object() == primitiveConstant).snapshot().forEach(node -> { + var usagesSnapshot = node.usages().snapshot(); + node.replaceAtAllUsages(primitiveConstant, true); + node.safeDelete(); + + usagesSnapshot.forEach(n -> { + if (n instanceof UnboxNode unboxNode) { + var prev = n.predecessor(); + + unboxNode.replaceAtAllUsages(primitiveConstant, true); + graph.removeFixed(unboxNode); + + if (prev instanceof FixedGuardNode fixedGuardNode) { + if (fixedGuardNode.condition() instanceof IsNullNode isNullNode && isNullNode.getValue() == primitiveConstant) { + fixedGuardNode.clearInputs(); + graph.removeFixed(fixedGuardNode); + } + } + } + }); + }); } } else { parameterNode.usages().snapshot().forEach(n -> {