diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 5193c5b53..5f30a1931 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -129,6 +129,11 @@ public OUTPUT visit(Expression.MapLiteral expr) throws EXCEPTION { return visitFallback(expr); } + @Override + public OUTPUT visit(Expression.EmptyMapLiteral expr) throws EXCEPTION { + return visitFallback(expr); + } + @Override public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION { return visitFallback(expr); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 7a049580b..9fc719f35 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -498,6 +498,25 @@ public R accept(ExpressionVisitor visitor) throws } } + @Value.Immutable + abstract static class EmptyMapLiteral implements Literal { + public abstract Type keyType(); + + public abstract Type valueType(); + + public Type getType() { + return Type.withNullability(nullable()).map(keyType(), valueType()); + } + + public static ImmutableExpression.EmptyMapLiteral.Builder builder() { + return ImmutableExpression.EmptyMapLiteral.builder(); + } + + public R accept(ExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract static class ListLiteral implements Literal { public abstract List values(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 2ec0b3b40..55e71b78d 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -252,6 +252,15 @@ public static Expression.MapLiteral map( return Expression.MapLiteral.builder().nullable(nullable).putAllValues(values).build(); } + public static Expression.EmptyMapLiteral emptyMap( + boolean nullable, Type keyType, Type valueType) { + return Expression.EmptyMapLiteral.builder() + .keyType(keyType) + .valueType(valueType) + .nullable(nullable) + .build(); + } + public static Expression.ListLiteral list(boolean nullable, Expression.Literal... values) { return Expression.ListLiteral.builder().nullable(nullable).addValues(values).build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 7bcdd4eab..b27a241a2 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -53,6 +53,8 @@ public interface ExpressionVisitor { R visit(Expression.MapLiteral expr) throws E; + R visit(Expression.EmptyMapLiteral expr) throws E; + R visit(Expression.ListLiteral expr) throws E; R visit(Expression.EmptyListLiteral expr) throws E; diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index a400522b1..255e86cc7 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -241,6 +241,21 @@ public Expression visit(io.substrait.expression.Expression.MapLiteral expr) { }); } + @Override + public Expression visit(io.substrait.expression.Expression.EmptyMapLiteral expr) { + return lit( + bldr -> { + var protoMapType = expr.getType().accept(typeProtoConverter); + bldr.setEmptyMap(protoMapType.getMap()) + // For empty maps, the Literal message's own nullable field should be ignored + // in favor of the nullability of the Type.Map in the literal's + // empty_map field. But for safety we set the literal's nullable field + // to match in case any readers either look in the wrong location + // or want to verify that they are consistent. + .setNullable(expr.nullable()); + }); + } + @Override public Expression visit(io.substrait.expression.Expression.ListLiteral expr) { return lit( diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 44a4aa24e..d2b95d74f 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -391,6 +391,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { literal.getNullable(), literal.getMap().getKeyValuesList().stream() .collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue())))); + case EMPTY_MAP -> { + // literal.getNullable() is intentionally ignored in favor of the nullability + // specified in the literal.getEmptyMap() type. + var mapType = protoTypeConverter.fromMap(literal.getEmptyMap()); + yield ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value()); + } case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid()); case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull())); case LIST -> ExpressionCreator.list( diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 8a2fefb37..29bbe1a8c 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -153,6 +153,11 @@ public Optional visit(Expression.MapLiteral expr) throws EXCEPTION { return visitLiteral(expr); } + @Override + public Optional visit(Expression.EmptyMapLiteral expr) throws EXCEPTION { + return visitLiteral(expr); + } + @Override public Optional visit(Expression.ListLiteral expr) throws EXCEPTION { return visitLiteral(expr); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index a53d71aa6..b94a468f6 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -95,7 +95,7 @@ public Type.ListType list(Type type) { return Type.ListType.builder().nullable(nullable).elementType(type).build(); } - public Type map(Type key, Type value) { + public Type.Map map(Type key, Type value) { return Type.Map.builder().nullable(nullable).key(key).value(value).build(); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 90776ba23..8bea1a42a 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -57,8 +57,7 @@ public Type from(io.substrait.proto.Type type) { .map(this::from) .collect(java.util.stream.Collectors.toList())); case LIST -> fromList(type.getList()); - case MAP -> n(type.getMap().getNullability()) - .map(from(type.getMap().getKey()), from(type.getMap().getValue())); + case MAP -> fromMap(type.getMap()); case USER_DEFINED -> { var userDefined = type.getUserDefined(); var t = lookup.getType(userDefined.getTypeReference(), extensions); @@ -74,6 +73,10 @@ public Type.ListType fromList(io.substrait.proto.Type.List list) { return n(list.getNullability()).list(from(list.getType())); } + public Type.Map fromMap(io.substrait.proto.Type.Map map) { + return n(map.getNullability()).map(from(map.getKey()), from(map.getValue())); + } + public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) { return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability; } diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 421a040d2..29c951363 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -72,4 +72,8 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { override def visit(expr: Expression.UserDefinedLiteral): String = { expr.toString } + + override def visit(expr: Expression.EmptyMapLiteral): String = { + expr.toString + } } diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala index fa5ec4fea..951900cd8 100644 --- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala +++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala @@ -44,11 +44,38 @@ private class ToSparkType override def visit(expr: Type.Str): DataType = StringType + override def visit(expr: Type.Binary): DataType = BinaryType + override def visit(expr: Type.FixedChar): DataType = StringType override def visit(expr: Type.VarChar): DataType = StringType override def visit(expr: Type.Bool): DataType = BooleanType + + override def visit(expr: Type.PrecisionTimestamp): DataType = { + Util.assertMicroseconds(expr.precision()) + TimestampNTZType + } + override def visit(expr: Type.PrecisionTimestampTZ): DataType = { + Util.assertMicroseconds(expr.precision()) + TimestampType + } + + override def visit(expr: Type.IntervalDay): DataType = { + Util.assertMicroseconds(expr.precision()) + DayTimeIntervalType.DEFAULT + } + + override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT + + override def visit(expr: Type.ListType): DataType = + ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable()) + + override def visit(expr: Type.Map): DataType = + MapType( + expr.key().accept(this), + expr.value().accept(this), + valueContainsNull = expr.value().nullable()) } class ToSubstraitType { @@ -81,10 +108,12 @@ class ToSubstraitType { case charType: CharType => Some(creator.fixedChar(charType.length)) case varcharType: VarcharType => Some(creator.varChar(varcharType.length)) case StringType => Some(creator.STRING) - case DateType => Some(creator.DATE) - case TimestampType => Some(creator.TIMESTAMP) - case TimestampNTZType => Some(creator.TIMESTAMP_TZ) case BinaryType => Some(creator.BINARY) + case DateType => Some(creator.DATE) + case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION)) + case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION)) + case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION)) + case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR) case ArrayType(elementType, containsNull) => convert(elementType, Seq.empty, containsNull).map(creator.list) case MapType(keyType, valueType, valueContainsNull) => diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index 62b4bfcd9..d531cbc01 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -28,8 +28,9 @@ import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} import io.substrait.util.DecimalUtil +import io.substrait.utils.Util -import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter} class ToSparkExpression( val scalarFunctionConverter: ToScalarFunction, @@ -61,6 +62,10 @@ class ToSparkExpression( Literal(expr.value(), ToSubstraitType.convert(expr.getType)) } + override def visit(expr: SExpression.FP32Literal): Literal = { + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.FP64Literal): Expression = { Literal(expr.value(), ToSubstraitType.convert(expr.getType)) } @@ -77,15 +82,71 @@ class ToSparkExpression( Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType)) } + override def visit(expr: SExpression.BinaryLiteral): Literal = { + Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.DecimalLiteral): Expression = { val value = expr.value.toByteArray val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16) Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType)) } + override def visit(expr: SExpression.DateLiteral): Expression = { Literal(expr.value(), ToSubstraitType.convert(expr.getType)) } + override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = { + // Spark timestamps are stored as a microseconds Long + Util.assertMicroseconds(expr.precision()) + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = { + // Spark timestamps are stored as a microseconds Long + Util.assertMicroseconds(expr.precision()) + Literal(expr.value(), ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.IntervalDayLiteral): Literal = { + Util.assertMicroseconds(expr.precision()) + // Spark uses a single microseconds Long as the "physical" type for DayTimeInterval + val micros = + (expr.days() * Util.SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND + + expr.subseconds() + Literal(micros, ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.IntervalYearLiteral): Literal = { + // Spark uses a single months Int as the "physical" type for YearMonthInterval + val months = expr.years() * 12 + expr.months() + Literal(months, ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.ListLiteral): Literal = { + val array = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value) + Literal.create(array, ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.EmptyListLiteral): Expression = { + Literal.default(ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.MapLiteral): Literal = { + val map = expr.values().asScala.map { + case (key, value) => + ( + key.accept(this).asInstanceOf[Literal].value, + value.accept(this).asInstanceOf[Literal].value + ) + } + Literal.create(map, ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.EmptyMapLiteral): Literal = { + Literal.default(ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.NullLiteral): Expression = { Literal(null, ToSubstraitType.convert(expr.getType)) } @@ -98,6 +159,7 @@ class ToSparkExpression( override def visit(expr: exp.FieldReference): Expression = { withFieldReference(expr)(i => currentOutput(i).clone()) } + override def visit(expr: SExpression.IfThen): Expression = { val branches = expr .ifClauses() diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala index 73362e982..4acf4a2f0 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala @@ -19,11 +19,15 @@ package io.substrait.spark.expression import io.substrait.spark.ToSubstraitType import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import io.substrait.expression.{Expression => SExpression} import io.substrait.expression.ExpressionCreator._ +import io.substrait.utils.Util + +import scala.collection.JavaConverters class ToSubstraitLiteral { @@ -34,6 +38,34 @@ class ToSubstraitLiteral { scale: Int): SExpression.Literal = decimal(false, d.toJavaBigDecimal, precision, scale) + private def sparkArray2Substrait( + arrayData: ArrayData, + elementType: DataType, + containsNull: Boolean): SExpression.Literal = { + val elements = arrayData.array.map(any => apply(Literal(any, elementType))) + if (elements.isEmpty) { + return emptyList(false, ToSubstraitType.convert(elementType, nullable = containsNull).get) + } + list(false, JavaConverters.asJavaIterable(elements)) // TODO: handle containsNull + } + + private def sparkMap2Substrait( + mapData: MapData, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): SExpression.Literal = { + val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType))) + val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType))) + if (keys.isEmpty) { + return emptyMap( + false, + ToSubstraitType.convert(keyType, nullable = false).get, + ToSubstraitType.convert(valueType, nullable = valueContainsNull).get) + } + // TODO: handle valueContainsNull + map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap)) + } + val _bool: Boolean => SExpression.Literal = bool(false, _) val _i8: Byte => SExpression.Literal = i8(false, _) val _i16: Short => SExpression.Literal = i16(false, _) @@ -43,7 +75,21 @@ class ToSubstraitLiteral { val _fp64: Double => SExpression.Literal = fp64(false, _) val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait val _date: Int => SExpression.Literal = date(false, _) + val _timestamp: Long => SExpression.Literal = + precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds + val _timestampTz: Long => SExpression.Literal = + precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds + val _intervalDay: Long => SExpression.Literal = (ms: Long) => { + val days = (ms / Util.MICROS_PER_SECOND / Util.SECONDS_PER_DAY).toInt + val seconds = (ms / Util.MICROS_PER_SECOND % Util.SECONDS_PER_DAY).toInt + val micros = ms % Util.MICROS_PER_SECOND + intervalDay(false, days, seconds, micros, Util.MICROSECOND_PRECISION) + } + val _intervalYear: Int => SExpression.Literal = (m: Int) => intervalYear(false, m / 12, m % 12) val _string: String => SExpression.Literal = string(false, _) + val _binary: Array[Byte] => SExpression.Literal = binary(false, _) + val _array: (ArrayData, DataType, Boolean) => SExpression.Literal = sparkArray2Substrait + val _map: (MapData, DataType, DataType, Boolean) => SExpression.Literal = sparkMap2Substrait } private def convertWithValue(literal: Literal): Option[SExpression.Literal] = { @@ -59,7 +105,16 @@ class ToSubstraitLiteral { case Literal(d: Decimal, dataType: DecimalType) => Nonnull._decimal(d, dataType.precision, dataType.scale) case Literal(d: Integer, DateType) => Nonnull._date(d) + case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t) + case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t) + case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d) + case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym) case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString) + case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b) + case Literal(a: ArrayData, ArrayType(et, containsNull)) => + Nonnull._array(a, et, containsNull) + case Literal(m: MapData, MapType(keyType, valueType, valueContainsNull)) => + Nonnull._map(m, keyType, valueType, valueContainsNull) case _ => null } ) diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala index 165d59953..37a6a631a 100644 --- a/spark/src/main/scala/io/substrait/utils/Util.scala +++ b/spark/src/main/scala/io/substrait/utils/Util.scala @@ -21,6 +21,18 @@ import scala.collection.mutable.ArrayBuffer object Util { + val SECONDS_PER_DAY: Long = 24 * 60 * 60 + val MICROS_PER_SECOND: Long = 1000 * 1000 + val MICROSECOND_PRECISION = 6 // for PrecisionTimestamp(TZ) and IntervalDay types + + def assertMicroseconds(precision: Int): Unit = { + // Spark uses microseconds as a Long value as the "physical" type for most time things + if (precision != MICROSECOND_PRECISION) { + throw new UnsupportedOperationException( + s"Unsupported precision: $precision. Only microsecond precision ($MICROSECOND_PRECISION) is supported") + } + } + /** * Compute the cartesian product for n lists. * diff --git a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala new file mode 100644 index 000000000..5246c0069 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala @@ -0,0 +1,120 @@ +package io.substrait.spark + +import io.substrait.spark.expression.{ToSparkExpression, ToSubstraitLiteral} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.util.MapData +import org.apache.spark.sql.types._ +import org.apache.spark.substrait.SparkTypeUtil +import org.apache.spark.unsafe.types.UTF8String + +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} + +class TypesAndLiteralsSuite extends SparkFunSuite { + + val toSparkExpression = new ToSparkExpression(null, null) + + val types: Seq[DataType] = List( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + StringType, + BinaryType, + BooleanType, + DecimalType(10, 2), + TimestampNTZType, + TimestampType, + DayTimeIntervalType.DEFAULT, + YearMonthIntervalType.DEFAULT, + ArrayType(IntegerType, containsNull = false), + ArrayType(IntegerType, containsNull = true), + MapType(IntegerType, StringType, valueContainsNull = false), + MapType(IntegerType, StringType, valueContainsNull = true) + ) + + types.foreach( + t => { + test(s"test type: $t") { + // Nullability doesn't matter as in Spark it's not a property of the type + val substraitType = ToSubstraitType.convert(t, nullable = true).get + val sparkType = ToSubstraitType.convert(substraitType) + + println("Before: " + t) + println("After: " + sparkType) + println("Substrait: " + substraitType) + + assert(t == sparkType) + } + }) + + val defaultLiterals: Seq[Literal] = types.map(Literal.default) + + val literals: Seq[Literal] = List( + Literal(1.toByte), + Literal(1.toShort), + Literal(1), + Literal(1L), + Literal(1.0f), + Literal(1.0), + Literal("1"), + Literal(Array[Byte](1)), + Literal(true), + Literal(BigDecimal("123.4567890")), + Literal(Instant.now()), // Timestamp + Literal(LocalDateTime.now()), // TimestampNTZ + Literal(LocalDate.now()), // Date + Literal(Duration.ofDays(1)), // DayTimeInterval + Literal( + Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4).plusMillis(5) + ), // DayTimeInterval + Literal(Period.ofYears(1)), // YearMonthInterval + Literal(Period.of(1, 2, 0)), // YearMonthInterval, days are ignored + Literal.create(Array(1, 2, 3), ArrayType(IntegerType, containsNull = false)) +// Literal.create(Array(1, null, 3), ArrayType(IntegerType, containsNull = true)) // TODO: handle containsNulls + ) + + (defaultLiterals ++ literals).foreach( + l => { + test(s"test literal: $l (${l.dataType})") { + val substraitLiteral = ToSubstraitLiteral.convert(l).get + val sparkLiteral = substraitLiteral.accept(toSparkExpression).asInstanceOf[Literal] + + println("Before: " + l + " " + l.dataType) + println("After: " + sparkLiteral + " " + sparkLiteral.dataType) + println("Substrait: " + substraitLiteral) + + assert(l.dataType == sparkLiteral.dataType) // makes understanding failures easier + assert(l == sparkLiteral) + } + }) + + test(s"test map literal") { + val l = Literal.create( + Map(1 -> "a", 2 -> "b"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + val substraitLiteral = ToSubstraitLiteral.convert(l).get + val sparkLiteral = substraitLiteral.accept(toSparkExpression).asInstanceOf[Literal] + + println("Before: " + l + " " + l.dataType) + println("After: " + sparkLiteral + " " + sparkLiteral.dataType) + println("Substrait: " + substraitLiteral) + + assert(l.dataType == sparkLiteral.dataType) // makes understanding failures easier + assert(SparkTypeUtil.sameType(l.dataType, sparkLiteral.dataType)) + + // MapData doesn't implement equality so we have to compare the arrays manually + val originalKeys = l.value.asInstanceOf[MapData].keyArray().toIntArray().sorted + val sparkKeys = sparkLiteral.value.asInstanceOf[MapData].keyArray().toIntArray().sorted + assert(originalKeys.sameElements(sparkKeys)) + + val originalValues = l.value.asInstanceOf[MapData].valueArray().toArray[UTF8String](StringType) + val sparkValues = + sparkLiteral.value.asInstanceOf[MapData].valueArray().toArray[UTF8String](StringType) + assert(originalValues.sorted.sameElements(sparkValues.sorted)) + } +}