diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index d393c327..29590137 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -128,7 +128,7 @@ class ToSubstraitType {
     )
   }
 
-  def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
+  def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
     namedStruct
       .struct()
       .fields()
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index cdc54b2e..9d49303b 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -20,6 +20,7 @@ import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType}
 import io.substrait.spark.expression._
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
@@ -252,8 +253,20 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
   }
 
   override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
-    LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema))
+    LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
   }
+
+  override def visit(virtualTableScan: relation.VirtualTableScan): LogicalPlan = {
+    val rows = virtualTableScan.getRows.asScala.map(
+      row =>
+        InternalRow.fromSeq(
+          row
+            .fields()
+            .asScala
+            .map(field => field.accept(expressionConverter).asInstanceOf[Literal].value)))
+    LocalRelation(ToSubstraitType.toAttributeSeq(virtualTableScan.getInitialSchema), rows)
+  }
+
   override def visit(namedScan: relation.NamedScan): LogicalPlan = {
     resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {
       case m: MultiInstanceRelation => m.newInstance()
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
index e6ce3a90..c27c98a9 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -333,7 +333,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
     var idx = 0
     val buf = new ArrayBuffer[SExpression.Literal](row.numFields)
     while (idx < row.numFields) {
-      val l = Literal.apply(row.get(idx, localRelation.schema(idx).dataType))
+      val dt = localRelation.schema(idx).dataType
+      val l = Literal.apply(row.get(idx, dt), dt)
       buf += ToSubstraitLiteral.apply(l)
       idx += 1
     }
diff --git a/spark/src/test/scala/io/substrait/spark/RelationsSuite.scala b/spark/src/test/scala/io/substrait/spark/RelationsSuite.scala
new file mode 100644
index 00000000..e29d8034
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/RelationsSuite.scala
@@ -0,0 +1,25 @@
+package io.substrait.spark
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RelationsSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase {
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sparkContext.setLogLevel("WARN")
+  }
+
+  test("local_relation_simple") {
+    assertSqlSubstraitRelRoundTrip(
+      "select * from (values (1, 'a'), (2, 'b') as table(col1, col2))"
+    )
+  }
+
+  test("local_relation_null") {
+    assertSqlSubstraitRelRoundTrip(
+      "select * from (values (1), (NULL) as table(col))"
+    )
+  }
+
+}