diff --git a/tornado-assembly/src/bin/tornado-test b/tornado-assembly/src/bin/tornado-test index aabc61e782..b4262fa85c 100755 --- a/tornado-assembly/src/bin/tornado-test +++ b/tornado-assembly/src/bin/tornado-test @@ -95,6 +95,7 @@ __TEST_THE_WORLD__ = [ TestEntry("uk.ac.manchester.tornado.unittests.math.TestMath"), TestEntry("uk.ac.manchester.tornado.unittests.batches.TestBatches"), TestEntry("uk.ac.manchester.tornado.unittests.lambdas.TestLambdas"), + TestEntry("uk.ac.manchester.tornado.unittests.functional.TestLambdas"), TestEntry("uk.ac.manchester.tornado.unittests.flatmap.TestFlatMap"), TestEntry("uk.ac.manchester.tornado.unittests.logic.TestLogic"), TestEntry("uk.ac.manchester.tornado.unittests.fields.TestFields"), diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoTaskSpecialisation.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoTaskSpecialisation.java index 9da3da2f65..2b1da197c4 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoTaskSpecialisation.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoTaskSpecialisation.java @@ -37,6 +37,7 @@ import org.graalvm.compiler.graph.Node; import org.graalvm.compiler.graph.iterators.NodeIterable; import org.graalvm.compiler.nodes.ConstantNode; +import org.graalvm.compiler.nodes.FixedGuardNode; import org.graalvm.compiler.nodes.GraphState; import org.graalvm.compiler.nodes.LogicConstantNode; import org.graalvm.compiler.nodes.NodeView; @@ -46,6 +47,7 @@ import org.graalvm.compiler.nodes.StructuredGraph; import org.graalvm.compiler.nodes.calc.IntegerLessThanNode; import org.graalvm.compiler.nodes.calc.IsNullNode; +import org.graalvm.compiler.nodes.extended.UnboxNode; import org.graalvm.compiler.nodes.java.ArrayLengthNode; import org.graalvm.compiler.nodes.java.LoadFieldNode; import org.graalvm.compiler.nodes.util.GraphUtil; @@ -261,20 +263,22 @@ private void evaluate(final StructuredGraph graph, final Node node, final Object } } - private ConstantNode createConstantFromObject(Object obj, StructuredGraph graph) { - ConstantNode result = null; - switch (obj) { - case Byte objByte -> result = ConstantNode.forByte(objByte, graph); - case Character objChar -> result = ConstantNode.forChar(objChar, graph); - case Short objShort -> result = ConstantNode.forShort(objShort, graph); - case HalfFloat objHalfFloat -> result = ConstantNode.forFloat(objHalfFloat.getFloat32(), graph); - case Integer objInteger -> result = ConstantNode.forInt(objInteger, graph); - case Float objFloat -> result = ConstantNode.forFloat(objFloat, graph); - case Double objDouble -> result = ConstantNode.forDouble(objDouble, graph); - case Long objLong -> result = ConstantNode.forLong(objLong, graph); - case null, default -> unimplemented("createConstantFromObject: %s", obj); - } - return result; + private ConstantNode createPrimitiveConstantFromObjectParameter(Object obj, StructuredGraph graph) { + return switch (obj) { + case Boolean objBoolean -> ConstantNode.forBoolean(objBoolean, graph); + case Byte objByte -> ConstantNode.forByte(objByte, graph); + case Character objChar -> ConstantNode.forChar(objChar, graph); + case Short objShort -> ConstantNode.forShort(objShort, graph); + case HalfFloat objHalfFloat -> ConstantNode.forFloat(objHalfFloat.getFloat32(), graph); + case Integer objInteger -> ConstantNode.forInt(objInteger, graph); + case Float objFloat -> ConstantNode.forFloat(objFloat, graph); + case Double objDouble -> ConstantNode.forDouble(objDouble, graph); + case Long objLong -> ConstantNode.forLong(objLong, graph); + case Object object -> { + unimplemented("createPrimitiveConstantFromObjectParameter: %s", obj); + yield null; + } + }; } private boolean isParameterInvolvedInParallelLoopBound(Node parameterNode) { @@ -301,8 +305,34 @@ private void propagateParameters(StructuredGraph graph, ParameterNode parameterN parameterNode.replaceAtUsages(kernelContextAccessNode); index++; } else { - ConstantNode constant = createConstantFromObject(args[parameterNode.index()], graph); - parameterNode.replaceAtUsages(constant); + var value = args[parameterNode.index()]; + ConstantNode primitiveConstant = createPrimitiveConstantFromObjectParameter(value, graph); + + parameterNode.replaceAtAllUsages(primitiveConstant, true); + parameterNode.safeDelete(); + + //remove Unbox nodes, they are not needed for constant values and cause compilation errors + graph.getNodes().filter(n -> n instanceof PiNode piNode && piNode.object() == primitiveConstant).snapshot().forEach(node -> { + var usagesSnapshot = node.usages().snapshot(); + node.replaceAtAllUsages(primitiveConstant, true); + node.safeDelete(); + + usagesSnapshot.forEach(n -> { + if (n instanceof UnboxNode unboxNode) { + var prev = n.predecessor(); + + unboxNode.replaceAtAllUsages(primitiveConstant, true); + graph.removeFixed(unboxNode); + + if (prev instanceof FixedGuardNode fixedGuardNode) { + if (fixedGuardNode.condition() instanceof IsNullNode isNullNode && isNullNode.getValue() == primitiveConstant) { + fixedGuardNode.clearInputs(); + graph.removeFixed(fixedGuardNode); + } + } + } + }); + }); } } else { parameterNode.usages().snapshot().forEach(n -> {