Skip to content

Commit

Permalink
feat(spark): bitwise functions
Browse files Browse the repository at this point in the history
Adds support in the spark module for 8-bit and 16-bit integer types and for some bitwise functions.
The catalyst optimizer generates expressions using these for certain query types.

Note that `shift_right` (and the other bit-shifting functions) may be worth considering for the
core Substrait function catalog, but it has been added here (temporarily?) as a Spark extension,
pending a longer-term discussion and decision on its wider utility.

Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Oct 25, 2024
1 parent 6413e55 commit 024b800
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 4 deletions.
19 changes: 19 additions & 0 deletions spark/src/main/resources/spark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,22 @@ scalar_functions:
- args:
- value: i64
return: DECIMAL<P,S>
- name: shift_right
description: >-
Bitwise (signed) shift right.
Params:
base – the base number to shift.
shift – number of bits to right shift.
impls:
- args:
- name: base
value: i64
- name: shift
value: i32
return: i64
- args:
- name: base
value: i32
- name: shift
value: i32
return: i32
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter
private class ToSparkType
extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") {

// Fixed-width integer types: each Substrait integer type is mapped to the
// Spark DataType returned on the right-hand side (I8 -> ByteType,
// I16 -> ShortType, I32 -> IntegerType, I64 -> LongType).
override def visit(expr: Type.I8): DataType = ByteType
override def visit(expr: Type.I16): DataType = ShortType
override def visit(expr: Type.I32): DataType = IntegerType
override def visit(expr: Type.I64): DataType = LongType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class FunctionMappings {
s[Concat]("concat"),
s[Coalesce]("coalesce"),
s[Year]("year"),
s[ShiftRight]("shift_right"),
s[BitwiseAnd]("bitwise_and"),
s[BitwiseOr]("bitwise_or"),
s[BitwiseXor]("bitwise_xor"),

// internal
s[MakeDecimal]("make_decimal"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ class ToSparkExpression(
Literal.FalseLiteral
}
}

override def visit(expr: SExpression.I8Literal): Expression = {
  // Convert a Substrait i8 literal into a Spark Literal.
  // The value is cast to Byte before being paired with the Spark type derived
  // from the literal's own Substrait type.
  // NOTE(review): presumably value() returns a wider integer than Byte, hence
  // the cast — confirm against the Substrait expression API.
  val byteValue = expr.value().asInstanceOf[Byte]
  val sparkType = ToSubstraitType.convert(expr.getType)
  Literal(byteValue, sparkType)
}

override def visit(expr: SExpression.I16Literal): Expression = {
  // Convert a Substrait i16 literal into a Spark Literal.
  // The value is cast to Short before being paired with the Spark type derived
  // from the literal's own Substrait type.
  // NOTE(review): presumably value() returns a wider integer than Short, hence
  // the cast — confirm against the Substrait expression API.
  val shortValue = expr.value().asInstanceOf[Short]
  val sparkType = ToSubstraitType.convert(expr.getType)
  Literal(shortValue, sparkType)
}

override def visit(expr: SExpression.I32Literal): Expression = {
  // Convert a Substrait i32 literal into a Spark Literal: the value is passed
  // through unchanged, paired with the Spark type derived from the literal's
  // own Substrait type.
  val intValue = expr.value()
  val sparkType = ToSubstraitType.convert(expr.getType)
  Literal(intValue, sparkType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,12 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
"org.apache.spark.sql.catalyst.expressions.PromotePrecision") =>
translateUp(p.children.head)
case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
case In(value, list) => translateIn(value, list)
case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v)))
case scalar @ ScalarFunction(children) =>
Util
.seqToOption(children.map(translateUp))
.flatMap(toScalarFunction.convert(scalar, _))
case In(value, list) => translateIn(value, list)
case p: PlanExpression[_] => translateSubQuery(p)
case other => default(other)
}
Expand Down
6 changes: 3 additions & 3 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
}

// spotless:off
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7",
"q11", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19",
"q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q28", "q29",
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q8",
"q11", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19",
"q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29",
"q30", "q31", "q32", "q33", "q37", "q38",
"q40", "q41", "q42", "q43", "q46", "q48",
"q50", "q52", "q54", "q55", "q56", "q58", "q59",
Expand Down

0 comments on commit 024b800

Please sign in to comment.