feat(spark): add support for more types and literals (binary, list, map, intervals, timestamps) #311

Merged

@@ -129,6 +129,11 @@ public OUTPUT visit(Expression.MapLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.EmptyMapLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java

@@ -498,6 +498,25 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract static class EmptyMapLiteral implements Literal {
public abstract Type keyType();

public abstract Type valueType();

public Type getType() {
return Type.withNullability(nullable()).map(keyType(), valueType());
}

public static ImmutableExpression.EmptyMapLiteral.Builder builder() {
return ImmutableExpression.EmptyMapLiteral.builder();
}

public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class ListLiteral implements Literal {
public abstract List<Literal> values();

@@ -252,6 +252,15 @@ public static Expression.MapLiteral map(
return Expression.MapLiteral.builder().nullable(nullable).putAllValues(values).build();
}

public static Expression.EmptyMapLiteral emptyMap(
boolean nullable, Type keyType, Type valueType) {
return Expression.EmptyMapLiteral.builder()
.keyType(keyType)
.valueType(valueType)
.nullable(nullable)
.build();
}

public static Expression.ListLiteral list(boolean nullable, Expression.Literal... values) {
return Expression.ListLiteral.builder().nullable(nullable).addValues(values).build();
}
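
For illustration, a minimal Scala sketch of building the new empty-map literal through this factory (the repo's Java API called from Scala; the usage is assumed, not taken from the PR):

import io.substrait.expression.ExpressionCreator
import io.substrait.`type`.TypeCreator

// A non-nullable empty map literal with required string keys and nullable i32 values.
// As getType() above shows, key/value nullability travels on the types themselves.
val empty = ExpressionCreator.emptyMap(
  false,
  TypeCreator.REQUIRED.STRING,
  TypeCreator.NULLABLE.I32)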

@@ -53,6 +53,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.MapLiteral expr) throws E;

R visit(Expression.EmptyMapLiteral expr) throws E;

R visit(Expression.ListLiteral expr) throws E;

R visit(Expression.EmptyListLiteral expr) throws E;

@@ -241,6 +241,21 @@ public Expression visit(io.substrait.expression.Expression.MapLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.EmptyMapLiteral expr) {
return lit(
bldr -> {
var protoMapType = expr.getType().accept(typeProtoConverter);
bldr.setEmptyMap(protoMapType.getMap())
// For empty maps, the Literal message's own nullable field should be ignored
// in favor of the nullability of the Type.Map in the literal's
// empty_map field. But for safety we set the literal's nullable field
// to match in case any readers either look in the wrong location
// or want to verify that they are consistent.
.setNullable(expr.nullable());
});
}

@Override
public Expression visit(io.substrait.expression.Expression.ListLiteral expr) {
return lit(

@@ -391,6 +391,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
literal.getNullable(),
literal.getMap().getKeyValuesList().stream()
.collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue()))));
case EMPTY_MAP -> {
// literal.getNullable() is intentionally ignored in favor of the nullability
// specified in the literal.getEmptyMap() type.
var mapType = protoTypeConverter.fromMap(literal.getEmptyMap());
yield ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value());
}
case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid());
case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull()));
case LIST -> ExpressionCreator.list(
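
As the comments in both converters note, an empty map's nullability is carried by the Type.Map in the literal's empty_map field rather than by the Literal message's own nullable flag. A hedged Scala sketch of the consumer-side rule (the helper name is hypothetical):

import io.substrait.proto.{Expression => PExpression, Type => PType}

// Read nullability from the empty_map type; the literal-level nullable field
// is only set for safety by the producer above and is ignored here.
def emptyMapIsNullable(lit: PExpression.Literal): Boolean =
  lit.getEmptyMap.getNullability == PType.Nullability.NULLABILITY_NULLABLE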

@@ -153,6 +153,11 @@ public Optional<Expression> visit(Expression.MapLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.EmptyMapLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/type/TypeCreator.java

@@ -95,7 +95,7 @@ public Type.ListType list(Type type) {
return Type.ListType.builder().nullable(nullable).elementType(type).build();
}

public Type map(Type key, Type value) {
public Type.Map map(Type key, Type value) {
return Type.Map.builder().nullable(nullable).key(key).value(value).build();
}

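
The narrowed return type gives callers typed access to the map's key and value without casting; the new fromMap below relies on it. A small sketch (usage assumed):

import io.substrait.`type`.{Type, TypeCreator}

// `map` now returns Type.Map, so key()/value() are available directly.
val m: Type.Map = TypeCreator.NULLABLE.map(TypeCreator.REQUIRED.STRING, TypeCreator.NULLABLE.I32)
val keyType = m.key() // previously this required a cast from the wider Type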

@@ -57,8 +57,7 @@ public Type from(io.substrait.proto.Type type) {
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case LIST -> fromList(type.getList());
case MAP -> n(type.getMap().getNullability())
.map(from(type.getMap().getKey()), from(type.getMap().getValue()));
case MAP -> fromMap(type.getMap());
case USER_DEFINED -> {
var userDefined = type.getUserDefined();
var t = lookup.getType(userDefined.getTypeReference(), extensions);

@@ -74,6 +73,10 @@ public Type.ListType fromList(io.substrait.proto.Type.List list) {
return n(list.getNullability()).list(from(list.getType()));
}

public Type.Map fromMap(io.substrait.proto.Type.Map map) {
return n(map.getNullability()).map(from(map.getKey()), from(map.getValue()));
}

public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) {
return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability;
}

@@ -72,4 +72,8 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
override def visit(expr: Expression.UserDefinedLiteral): String = {
expr.toString
}

override def visit(expr: Expression.EmptyMapLiteral): String = {
expr.toString
}
}
35 changes: 32 additions & 3 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala

@@ -44,11 +44,38 @@ private class ToSparkType

override def visit(expr: Type.Str): DataType = StringType

override def visit(expr: Type.Binary): DataType = BinaryType

override def visit(expr: Type.FixedChar): DataType = StringType

override def visit(expr: Type.VarChar): DataType = StringType

override def visit(expr: Type.Bool): DataType = BooleanType

override def visit(expr: Type.PrecisionTimestamp): DataType = {
Util.assertMicroseconds(expr.precision())
TimestampNTZType
}
override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
Util.assertMicroseconds(expr.precision())
TimestampType
}

override def visit(expr: Type.IntervalDay): DataType = {
Util.assertMicroseconds(expr.precision())
DayTimeIntervalType.DEFAULT
}

override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT

override def visit(expr: Type.ListType): DataType =
ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())

override def visit(expr: Type.Map): DataType =
MapType(
expr.key().accept(this),
expr.value().accept(this),
valueContainsNull = expr.value().nullable())
}
class ToSubstraitType {

@@ -81,10 +108,12 @@ class ToSubstraitType {
case charType: CharType => Some(creator.fixedChar(charType.length))
case varcharType: VarcharType => Some(creator.varChar(varcharType.length))
case StringType => Some(creator.STRING)
case DateType => Some(creator.DATE)
case TimestampType => Some(creator.TIMESTAMP)
case TimestampNTZType => Some(creator.TIMESTAMP_TZ)
case BinaryType => Some(creator.BINARY)
case DateType => Some(creator.DATE)
case TimestampNTZType => Some(creator.precisionTimestamp(Util.MICROSECOND_PRECISION))

Review comment — @Blizzara (Contributor Author), Oct 24, 2024: This changes the produced type from the deprecated timestamp type to the new precision-aware version. That may break some consumers if they haven't been updated yet.

case TimestampType => Some(creator.precisionTimestampTZ(Util.MICROSECOND_PRECISION))
case DayTimeIntervalType.DEFAULT => Some(creator.intervalDay(Util.MICROSECOND_PRECISION))
case YearMonthIntervalType.DEFAULT => Some(creator.INTERVAL_YEAR)
case ArrayType(elementType, containsNull) =>
convert(elementType, Seq.empty, containsNull).map(creator.list)
case MapType(keyType, valueType, valueContainsNull) =>
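
A short Scala sketch of the new Substrait-to-Spark type mapping added above (illustrative; assumes the convert entry point used elsewhere in this diff):

import io.substrait.`type`.TypeCreator
import io.substrait.spark.ToSubstraitType

// Map<string, i32?> becomes MapType(StringType, IntegerType, valueContainsNull = true):
// valueContainsNull is derived from the Substrait value type's nullability.
val substraitMap = TypeCreator.REQUIRED.map(TypeCreator.REQUIRED.STRING, TypeCreator.NULLABLE.I32)
val sparkType = ToSubstraitType.convert(substraitMap)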

@@ -28,8 +28,9 @@ import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
import io.substrait.utils.Util

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter}

class ToSparkExpression(
val scalarFunctionConverter: ToScalarFunction,

@@ -61,6 +62,10 @@ class ToSparkExpression(
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.FP32Literal): Literal = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.FP64Literal): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

@@ -77,15 +82,71 @@ class ToSparkExpression(
Literal(UTF8String.fromString(expr.value()), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.BinaryLiteral): Literal = {
Literal(expr.value().toByteArray, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DecimalLiteral): Expression = {
val value = expr.value.toByteArray
val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16)
Literal(Decimal(decimal), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.DateLiteral): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = {
// Spark timestamps are stored as a microseconds Long
Util.assertMicroseconds(expr.precision())
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = {
// Spark timestamps are stored as a microseconds Long
Util.assertMicroseconds(expr.precision())
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.IntervalDayLiteral): Literal = {
Util.assertMicroseconds(expr.precision())
// Spark uses a single microseconds Long as the "physical" type for DayTimeInterval
val micros =
(expr.days() * Util.SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND +
expr.subseconds()
Literal(micros, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.IntervalYearLiteral): Literal = {
// Spark uses a single months Int as the "physical" type for YearMonthInterval
val months = expr.years() * 12 + expr.months()
Literal(months, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.ListLiteral): Literal = {
val array = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value)
Literal.create(array, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.EmptyListLiteral): Expression = {
Literal.default(ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.MapLiteral): Literal = {
val map = expr.values().asScala.map {
case (key, value) =>
(
key.accept(this).asInstanceOf[Literal].value,
value.accept(this).asInstanceOf[Literal].value
)
}
Literal.create(map, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.EmptyMapLiteral): Literal = {
Literal.default(ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.NullLiteral): Expression = {
Literal(null, ToSubstraitType.convert(expr.getType))
}

@@ -98,6 +159,7 @@ class ToSparkExpression(
override def visit(expr: exp.FieldReference): Expression = {
withFieldReference(expr)(i => currentOutput(i).clone())
}

override def visit(expr: SExpression.IfThen): Expression = {
val branches = expr
.ifClauses()
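
The IntervalDayLiteral visit above packs days, seconds and subseconds into the single microseconds Long that Spark uses for DayTimeIntervalType. A worked example (Scala sketch; the constants exist in this repo's Util, the values are assumed):

import io.substrait.utils.Util

// 1 day + 2h 3m 4s + 5 microseconds:
val days = 1L
val seconds = 2 * 3600 + 3 * 60 + 4 // 7384
val subseconds = 5L
val micros = (days * Util.SECONDS_PER_DAY + seconds) * Util.MICROS_PER_SECOND + subseconds
// micros == 93784000005L; ToSubstraitLiteral._intervalDay below inverts this with / and %.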

@@ -19,11 +19,15 @@ package io.substrait.spark.expression
import io.substrait.spark.ToSubstraitType

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.ExpressionCreator._
import io.substrait.utils.Util

import scala.collection.JavaConverters

class ToSubstraitLiteral {

@@ -34,6 +38,34 @@
scale: Int): SExpression.Literal =
decimal(false, d.toJavaBigDecimal, precision, scale)

private def sparkArray2Substrait(
arrayData: ArrayData,
elementType: DataType,
containsNull: Boolean): SExpression.Literal = {
val elements = arrayData.array.map(any => apply(Literal(any, elementType)))
if (elements.isEmpty) {
return emptyList(false, ToSubstraitType.convert(elementType, nullable = containsNull).get)
}
list(false, JavaConverters.asJavaIterable(elements)) // TODO: handle containsNull

Review thread on the line above:

@Blizzara (Contributor Author), Oct 24, 2024: Handling containsNull here is non-trivial, since Substrait doesn't have such a concept. The nullability of the non-empty list is taken from the nullability of the element type, which is taken from the first element. So we'll need to encode containsNull into the elementType. It can be done, but it requires some more work (there's no "withNullable" method or anything) and this PR is already big enough, so I'll do it as a follow-up. The same applies to map below.

Reviewer (Member): I'm not entirely sure what you mean here. "The nullability of the non-empty list is taken from the nullability of the element type, which is taken from the first element" — is this when going from Substrait to Spark, or from Spark to Substrait?

@Blizzara (Contributor Author): Substrait itself doesn't have a "containsNull" argument, so the Substrait -> Spark conversion has to infer it, and the Spark -> Substrait conversion doesn't preserve the information needed to infer it. It's easiest to see from the commented-out test. Starting with the Spark value [1,null,3] of type ArrayType(IntegerType,true), it gets converted into the Substrait literal ListLiteral{nullable=false, values=[I32Literal{nullable=false, value=1}, NullLiteral{nullable=false, type=I32{nullable=true}}, I32Literal{nullable=false, value=3}]}. Note that there is no "containsNull" anywhere. So when converting back into Spark we infer it by looking at the first element and its nullable flag, which today is always nullable=false, yielding the Spark value [1,null,3] with type ArrayType(IntegerType,false). (Funnily enough that's the case even if the first element is a null, but fixing that alone wouldn't be enough, since the null might be elsewhere in the list or even in a different row, and the literal's type should match across rows.) I think fixing this requires changing ToSubstraitLiteral.convert/convertWithValue so that the nullability they set can be overridden in this case.

}

private def sparkMap2Substrait(
mapData: MapData,
keyType: DataType,
valueType: DataType,
valueContainsNull: Boolean): SExpression.Literal = {
val keys = mapData.keyArray().array.map(any => apply(Literal(any, keyType)))
val values = mapData.valueArray().array.map(any => apply(Literal(any, valueType)))
if (keys.isEmpty) {
return emptyMap(
false,
ToSubstraitType.convert(keyType, nullable = false).get,
ToSubstraitType.convert(valueType, nullable = valueContainsNull).get)
}
// TODO: handle valueContainsNull
map(false, JavaConverters.mapAsJavaMap(keys.zip(values).toMap))
}

val _bool: Boolean => SExpression.Literal = bool(false, _)
val _i8: Byte => SExpression.Literal = i8(false, _)
val _i16: Short => SExpression.Literal = i16(false, _)
Expand All @@ -43,7 +75,21 @@ class ToSubstraitLiteral {
val _fp64: Double => SExpression.Literal = fp64(false, _)
val _decimal: (Decimal, Int, Int) => SExpression.Literal = sparkDecimal2Substrait
val _date: Int => SExpression.Literal = date(false, _)
val _timestamp: Long => SExpression.Literal =
precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _timestampTz: Long => SExpression.Literal =
precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _intervalDay: Long => SExpression.Literal = (ms: Long) => {
val days = (ms / Util.MICROS_PER_SECOND / Util.SECONDS_PER_DAY).toInt
val seconds = (ms / Util.MICROS_PER_SECOND % Util.SECONDS_PER_DAY).toInt
val micros = ms % Util.MICROS_PER_SECOND
intervalDay(false, days, seconds, micros, Util.MICROSECOND_PRECISION)
}
val _intervalYear: Int => SExpression.Literal = (m: Int) => intervalYear(false, m / 12, m % 12)
val _string: String => SExpression.Literal = string(false, _)
val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
val _array: (ArrayData, DataType, Boolean) => SExpression.Literal = sparkArray2Substrait
val _map: (MapData, DataType, DataType, Boolean) => SExpression.Literal = sparkMap2Substrait
}

private def convertWithValue(literal: Literal): Option[SExpression.Literal] = {
@@ -59,7 +105,16 @@
case Literal(d: Decimal, dataType: DecimalType) =>
Nonnull._decimal(d, dataType.precision, dataType.scale)
case Literal(d: Integer, DateType) => Nonnull._date(d)
case Literal(t: Long, TimestampType) => Nonnull._timestampTz(t)
case Literal(t: Long, TimestampNTZType) => Nonnull._timestamp(t)
case Literal(d: Long, DayTimeIntervalType.DEFAULT) => Nonnull._intervalDay(d)
case Literal(ym: Int, YearMonthIntervalType.DEFAULT) => Nonnull._intervalYear(ym)
case Literal(u: UTF8String, StringType) => Nonnull._string(u.toString)
case Literal(b: Array[Byte], BinaryType) => Nonnull._binary(b)
case Literal(a: ArrayData, ArrayType(et, containsNull)) =>
Nonnull._array(a, et, containsNull)
case Literal(m: MapData, MapType(keyType, valueType, valueContainsNull)) =>
Nonnull._map(m, keyType, valueType, valueContainsNull)
case _ => null
}
)
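
To close, a minimal sketch of the round-trip loss discussed in the containsNull review thread above (illustrative only; conversion entry points assumed):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// Spark side: [1, null, 3] with containsNull = true.
val in = Literal.create(Array[Any](1, null, 3), ArrayType(IntegerType, containsNull = true))
// Spark -> Substrait drops containsNull: the ListLiteral stores only per-element
// literals. Substrait -> Spark then infers containsNull from the first element's
// nullability, which is currently always false, so the round trip yields
// ArrayType(IntegerType, containsNull = false).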