Skip to content

Commit

Permalink
Fix of issue beehive-lab#521 in PTX.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrii0lomakin committed Aug 14, 2024
1 parent 1fb38e3 commit af7c125
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 16 deletions.
1 change: 1 addition & 0 deletions tornado-assembly/src/bin/tornado-test
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ __TEST_THE_WORLD__ = [
TestEntry("uk.ac.manchester.tornado.unittests.math.TestMath"),
TestEntry("uk.ac.manchester.tornado.unittests.batches.TestBatches"),
TestEntry("uk.ac.manchester.tornado.unittests.lambdas.TestLambdas"),
TestEntry("uk.ac.manchester.tornado.unittests.functional.TestLambdas"),
TestEntry("uk.ac.manchester.tornado.unittests.flatmap.TestFlatMap"),
TestEntry("uk.ac.manchester.tornado.unittests.logic.TestLogic"),
TestEntry("uk.ac.manchester.tornado.unittests.fields.TestFields"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.graph.iterators.NodeIterable;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.FixedGuardNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.LogicConstantNode;
import org.graalvm.compiler.nodes.NodeView;
Expand All @@ -46,6 +47,7 @@
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.calc.IntegerLessThanNode;
import org.graalvm.compiler.nodes.calc.IsNullNode;
import org.graalvm.compiler.nodes.extended.UnboxNode;
import org.graalvm.compiler.nodes.java.ArrayLengthNode;
import org.graalvm.compiler.nodes.java.LoadFieldNode;
import org.graalvm.compiler.nodes.util.GraphUtil;
Expand Down Expand Up @@ -261,20 +263,22 @@ private void evaluate(final StructuredGraph graph, final Node node, final Object
}
}

private ConstantNode createConstantFromObject(Object obj, StructuredGraph graph) {
ConstantNode result = null;
switch (obj) {
case Byte objByte -> result = ConstantNode.forByte(objByte, graph);
case Character objChar -> result = ConstantNode.forChar(objChar, graph);
case Short objShort -> result = ConstantNode.forShort(objShort, graph);
case HalfFloat objHalfFloat -> result = ConstantNode.forFloat(objHalfFloat.getFloat32(), graph);
case Integer objInteger -> result = ConstantNode.forInt(objInteger, graph);
case Float objFloat -> result = ConstantNode.forFloat(objFloat, graph);
case Double objDouble -> result = ConstantNode.forDouble(objDouble, graph);
case Long objLong -> result = ConstantNode.forLong(objLong, graph);
case null, default -> unimplemented("createConstantFromObject: %s", obj);
}
return result;
private ConstantNode createPrimitiveConstantFromObjectParameter(Object obj, StructuredGraph graph) {
return switch (obj) {
case Boolean objBoolean -> ConstantNode.forBoolean(objBoolean, graph);
case Byte objByte -> ConstantNode.forByte(objByte, graph);
case Character objChar -> ConstantNode.forChar(objChar, graph);
case Short objShort -> ConstantNode.forShort(objShort, graph);
case HalfFloat objHalfFloat -> ConstantNode.forFloat(objHalfFloat.getFloat32(), graph);
case Integer objInteger -> ConstantNode.forInt(objInteger, graph);
case Float objFloat -> ConstantNode.forFloat(objFloat, graph);
case Double objDouble -> ConstantNode.forDouble(objDouble, graph);
case Long objLong -> ConstantNode.forLong(objLong, graph);
case Object object -> {
unimplemented("createPrimitiveConstantFromObjectParameter: %s", obj);
yield null;
}
};
}

private boolean isParameterInvolvedInParallelLoopBound(Node parameterNode) {
Expand All @@ -301,8 +305,34 @@ private void propagateParameters(StructuredGraph graph, ParameterNode parameterN
parameterNode.replaceAtUsages(kernelContextAccessNode);
index++;
} else {
ConstantNode constant = createConstantFromObject(args[parameterNode.index()], graph);
parameterNode.replaceAtUsages(constant);
var value = args[parameterNode.index()];
ConstantNode primitiveConstant = createPrimitiveConstantFromObjectParameter(value, graph);

parameterNode.replaceAtAllUsages(primitiveConstant, true);
parameterNode.safeDelete();

//remove Unbox nodes, they are not needed for constant values and cause compilation errors
graph.getNodes().filter(n -> n instanceof PiNode piNode && piNode.object() == primitiveConstant).snapshot().forEach(node -> {
var usagesSnapshot = node.usages().snapshot();
node.replaceAtAllUsages(primitiveConstant, true);
node.safeDelete();

usagesSnapshot.forEach(n -> {
if (n instanceof UnboxNode unboxNode) {
var prev = n.predecessor();

unboxNode.replaceAtAllUsages(primitiveConstant, true);
graph.removeFixed(unboxNode);

if (prev instanceof FixedGuardNode fixedGuardNode) {
if (fixedGuardNode.condition() instanceof IsNullNode isNullNode && isNullNode.getValue() == primitiveConstant) {
fixedGuardNode.clearInputs();
graph.removeFixed(fixedGuardNode);
}
}
}
});
});
}
} else {
parameterNode.usages().snapshot().forEach(n -> {
Expand Down

0 comments on commit af7c125

Please sign in to comment.