feat: support converting VirtualTableScans into LocalRelations
also fix LocalRelation -> VirtualTableScan conversion for rows containing null values
Blizzara committed Oct 24, 2024
1 parent fc8a764 commit f384d92
Showing 4 changed files with 42 additions and 3 deletions.
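
For context: Spark resolves an inline VALUES table into a LocalRelation, and with this commit the converters map that LocalRelation to and from Substrait's VirtualTableScan. A minimal sketch of the Spark side, assuming an active SparkSession bound to spark (illustrative, not part of the diff):

val df = spark.sql("select * from (values (1, 'a'), (2, 'b') as table(col1, col2))")
// The analyzed/optimized plan is a LocalRelation carrying the literal rows;
// ToSubstraitRel can now emit it as a Substrait VirtualTableScan, and
// ToLogicalPlan can turn a VirtualTableScan back into a LocalRelation.
println(df.queryExecution.optimizedPlan)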
@@ -128,7 +128,7 @@ class ToSubstraitType {
)
}

- def toAttribute(namedStruct: NamedStruct): Seq[AttributeReference] = {
+ def toAttributeSeq(namedStruct: NamedStruct): Seq[AttributeReference] = {
namedStruct
.struct()
.fields()
@@ -20,6 +20,7 @@ import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType}
import io.substrait.spark.expression._

import org.apache.spark.sql.SparkSession
+ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
@@ -252,8 +253,20 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}

override def visit(emptyScan: relation.EmptyScan): LogicalPlan = {
- LocalRelation(ToSubstraitType.toAttribute(emptyScan.getInitialSchema))
+ LocalRelation(ToSubstraitType.toAttributeSeq(emptyScan.getInitialSchema))
}

+ override def visit(virtualTableScan: relation.VirtualTableScan): LogicalPlan = {
+   val rows = virtualTableScan.getRows.asScala.map(
+     row =>
+       InternalRow.fromSeq(
+         row
+           .fields()
+           .asScala
+           .map(field => field.accept(expressionConverter).asInstanceOf[Literal].value)))
+   LocalRelation(ToSubstraitType.toAttributeSeq(virtualTableScan.getInitialSchema), rows)
+ }

override def visit(namedScan: relation.NamedScan): LogicalPlan = {
resolve(UnresolvedRelation(namedScan.getNames.asScala)) match {
case m: MultiInstanceRelation => m.newInstance()
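
The new visit(virtualTableScan) above converts each Substrait row of struct-literal fields into a Catalyst InternalRow and pairs the rows with attributes derived from the scan's initial schema. A standalone sketch of the structure it produces (column names and values here are illustrative only):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Attributes mirror the scan's initial schema.
val output = Seq(
  AttributeReference("col1", IntegerType)(),
  AttributeReference("col2", StringType)())

// Each Substrait struct literal becomes one InternalRow of Catalyst values
// (strings are stored as UTF8String inside an InternalRow).
val rows = Seq(
  InternalRow(1, UTF8String.fromString("a")),
  InternalRow(2, UTF8String.fromString("b")))

val relation = LocalRelation(output, rows)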
@@ -333,7 +333,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
var idx = 0
val buf = new ArrayBuffer[SExpression.Literal](row.numFields)
while (idx < row.numFields) {
- val l = Literal.apply(row.get(idx, localRelation.schema(idx).dataType))
+ val dt = localRelation.schema(idx).dataType
+ val l = Literal.apply(row.get(idx, dt), dt)
buf += ToSubstraitLiteral.apply(l)
idx += 1
}
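
The extra dataType argument is what fixes the null case mentioned in the commit message: Literal.apply on a bare null infers NullType and drops the column's declared type, whereas constructing the Literal with the schema's type keeps a typed null that ToSubstraitLiteral can convert. Illustrative only, not part of the diff:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.IntegerType

Literal(null)               // Literal(null, NullType): the column type is lost
Literal(null, IntegerType)  // a typed null, preserving the LocalRelation's schema type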
25 changes: 25 additions & 0 deletions spark/src/test/scala/io/substrait/spark/RelationsSuite.scala
@@ -0,0 +1,25 @@
package io.substrait.spark

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSparkSession

class RelationsSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase {

override def beforeAll(): Unit = {
super.beforeAll()
sparkContext.setLogLevel("WARN")
}

test("local_relation_simple") {
assertSqlSubstraitRelRoundTrip(
"select * from (values (1, 'a'), (2, 'b') as table(col1, col2))"
)
}

test("local_relation_null") {
assertSqlSubstraitRelRoundTrip(
"select * from (values (1), (NULL) as table(col))"
)
}

}
