Skip to content

Commit eec034a

Browse files
committed
add complex type support to shuffle
1 parent b8be7b7 commit eec034a

File tree

4 files changed

+91
-87
lines changed

4 files changed

+91
-87
lines changed

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ class CometSparkSessionExtensions
240240
plan.transformUp {
241241
case s: ShuffleExchangeExec
242242
if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
243-
QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 =>
243+
QueryPlanSerde.nativeShuffleSupported(s)._1 =>
244244
logInfo("Comet extension enabled for Native Shuffle")
245245

246246
// Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -253,7 +253,7 @@ class CometSparkSessionExtensions
253253
case s: ShuffleExchangeExec
254254
if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(
255255
conf) &&
256-
QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
256+
QueryPlanSerde.columnarShuffleSupported(s)._1 &&
257257
!isShuffleOperator(s.child) =>
258258
logInfo("Comet extension enabled for JVM Columnar Shuffle")
259259
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -719,7 +719,7 @@ class CometSparkSessionExtensions
719719
case s: ShuffleExchangeExec =>
720720
val nativePrecondition = isCometShuffleEnabled(conf) &&
721721
isCometNativeShuffleMode(conf) &&
722-
QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1
722+
QueryPlanSerde.nativeShuffleSupported(s)._1
723723

724724
val nativeShuffle: Option[SparkPlan] =
725725
if (nativePrecondition) {
@@ -753,7 +753,7 @@ class CometSparkSessionExtensions
753753
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
754754
// convert it to CometColumnarShuffle,
755755
if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
756-
QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
756+
QueryPlanSerde.columnarShuffleSupported(s)._1 &&
757757
!isShuffleOperator(s.child)) {
758758

759759
val newOp = QueryPlanSerde.operator2Proto(s)
@@ -781,22 +781,22 @@ class CometSparkSessionExtensions
781781
nativeOrColumnarShuffle.get
782782
} else {
783783
val isShuffleEnabled = isCometShuffleEnabled(conf)
784-
val outputPartitioning = s.outputPartitioning
784+
s.outputPartitioning
785785
val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
786786
val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
787787
val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
788788
val msg2 = createMessage(
789789
isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
790-
.supportPartitioning(s.child.output, outputPartitioning)
790+
.nativeShuffleSupported(s)
791791
._1,
792792
"Native shuffle: " +
793-
s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}")
793+
s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
794794
val typeInfo = QueryPlanSerde
795-
.supportPartitioningTypes(s.child.output, outputPartitioning)
795+
.columnarShuffleSupported(s)
796796
._2
797797
val msg3 = createMessage(
798798
isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
799-
.supportPartitioningTypes(s.child.output, outputPartitioning)
799+
.columnarShuffleSupported(s)
800800
._1,
801801
"JVM shuffle: " +
802802
s"$typeInfo")

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 56 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
558558
case c @ Cast(child, dt, timeZoneId, _) =>
559559
handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
560560

561-
case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
561+
case add @ Add(left, right, _) if supportedShuffleDataType(left.dataType) =>
562562
createMathExpression(
563563
expr,
564564
left,
@@ -569,11 +569,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
569569
add.evalMode == EvalMode.ANSI,
570570
(builder, mathExpr) => builder.setAdd(mathExpr))
571571

572-
case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
572+
case add @ Add(left, _, _) if !supportedShuffleDataType(left.dataType) =>
573573
withInfo(add, s"Unsupported datatype ${left.dataType}")
574574
None
575575

576-
case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
576+
case sub @ Subtract(left, right, _) if supportedShuffleDataType(left.dataType) =>
577577
createMathExpression(
578578
expr,
579579
left,
@@ -584,11 +584,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
584584
sub.evalMode == EvalMode.ANSI,
585585
(builder, mathExpr) => builder.setSubtract(mathExpr))
586586

587-
case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
587+
case sub @ Subtract(left, _, _) if !supportedShuffleDataType(left.dataType) =>
588588
withInfo(sub, s"Unsupported datatype ${left.dataType}")
589589
None
590590

591-
case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) =>
591+
case mul @ Multiply(left, right, _) if supportedShuffleDataType(left.dataType) =>
592592
createMathExpression(
593593
expr,
594594
left,
@@ -600,12 +600,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
600600
(builder, mathExpr) => builder.setMultiply(mathExpr))
601601

602602
case mul @ Multiply(left, _, _) =>
603-
if (!supportedDataType(left.dataType)) {
603+
if (!supportedShuffleDataType(left.dataType)) {
604604
withInfo(mul, s"Unsupported datatype ${left.dataType}")
605605
}
606606
None
607607

608-
case div @ Divide(left, right, _) if supportedDataType(left.dataType) =>
608+
case div @ Divide(left, right, _) if supportedShuffleDataType(left.dataType) =>
609609
// Datafusion now throws an exception for dividing by zero
610610
// See https://github.com/apache/arrow-datafusion/pull/6792
611611
// For now, use NullIf to swap zeros with nulls.
@@ -622,12 +622,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
622622
(builder, mathExpr) => builder.setDivide(mathExpr))
623623

624624
case div @ Divide(left, _, _) =>
625-
if (!supportedDataType(left.dataType)) {
625+
if (!supportedShuffleDataType(left.dataType)) {
626626
withInfo(div, s"Unsupported datatype ${left.dataType}")
627627
}
628628
None
629629

630-
case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) =>
630+
case div @ IntegralDivide(left, right, _) if supportedShuffleDataType(left.dataType) =>
631631
val rightExpr = nullIfWhenPrimitive(right)
632632

633633
val dataType = (left.dataType, right.dataType) match {
@@ -671,12 +671,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
671671
}
672672

673673
case div @ IntegralDivide(left, _, _) =>
674-
if (!supportedDataType(left.dataType)) {
674+
if (!supportedShuffleDataType(left.dataType)) {
675675
withInfo(div, s"Unsupported datatype ${left.dataType}")
676676
}
677677
None
678678

679-
case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) =>
679+
case rem @ Remainder(left, right, _) if supportedShuffleDataType(left.dataType) =>
680680
val rightExpr = nullIfWhenPrimitive(right)
681681

682682
createMathExpression(
@@ -690,7 +690,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
690690
(builder, mathExpr) => builder.setRemainder(mathExpr))
691691

692692
case rem @ Remainder(left, _, _) =>
693-
if (!supportedDataType(left.dataType)) {
693+
if (!supportedShuffleDataType(left.dataType)) {
694694
withInfo(rem, s"Unsupported datatype ${left.dataType}")
695695
}
696696
None
@@ -816,7 +816,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
816816
withInfo(expr, s"Unsupported datatype $dataType")
817817
None
818818
}
819-
case Literal(_, dataType) if !supportedDataType(dataType) =>
819+
case Literal(_, dataType) if !supportedShuffleDataType(dataType) =>
820820
withInfo(expr, s"Unsupported datatype $dataType")
821821
None
822822

@@ -1786,7 +1786,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
17861786
ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
17871787
}
17881788

1789-
case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) =>
1789+
case s @ execution.ScalarSubquery(_, _) if supportedShuffleDataType(s.dataType) =>
17901790
val dataType = serializeDataType(s.dataType)
17911791
if (dataType.isEmpty) {
17921792
withInfo(s, s"Scalar subquery returns unsupported datatype ${s.dataType}")
@@ -2785,52 +2785,28 @@ object QueryPlanSerde extends Logging with CometExprShim {
27852785
* Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
27862786
* which supports struct/array.
27872787
*/
2788-
def supportPartitioningTypes(
2789-
inputs: Seq[Attribute],
2790-
partitioning: Partitioning): (Boolean, String) = {
2791-
def supportedDataType(dt: DataType): Boolean = dt match {
2792-
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
2793-
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
2794-
_: DateType | _: BooleanType =>
2795-
true
2796-
case StructType(fields) =>
2797-
fields.forall(f => supportedDataType(f.dataType)) &&
2798-
// Java Arrow stream reader cannot work on duplicate field name
2799-
fields.map(f => f.name).distinct.length == fields.length
2800-
case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
2801-
case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
2802-
case ArrayType(elementType, _) =>
2803-
supportedDataType(elementType)
2804-
case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
2805-
case MapType(_, MapType(_, _, _), _) => false
2806-
case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
2807-
case MapType(_, StructType(_), _) => false
2808-
case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
2809-
case MapType(_, ArrayType(_, _), _) => false
2810-
case MapType(keyType, valueType, _) =>
2811-
supportedDataType(keyType) && supportedDataType(valueType)
2812-
case _ =>
2813-
false
2814-
}
2815-
2788+
def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
2789+
val inputs = s.child.output
2790+
val partitioning = s.outputPartitioning
28162791
var msg = ""
28172792
val supported = partitioning match {
28182793
case HashPartitioning(expressions, _) =>
28192794
val supported =
28202795
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
2821-
expressions.forall(e => supportedDataType(e.dataType)) &&
2822-
inputs.forall(attr => supportedDataType(attr.dataType))
2796+
expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
2797+
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
28232798
if (!supported) {
28242799
msg = s"unsupported Spark partitioning expressions: $expressions"
28252800
}
28262801
supported
2827-
case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
2828-
case RoundRobinPartitioning(_) => inputs.forall(attr => supportedDataType(attr.dataType))
2802+
case SinglePartition => inputs.forall(attr => supportedShuffleDataType(attr.dataType))
2803+
case RoundRobinPartitioning(_) =>
2804+
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
28292805
case RangePartitioning(orderings, _) =>
28302806
val supported =
28312807
orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
2832-
orderings.forall(e => supportedDataType(e.dataType)) &&
2833-
inputs.forall(attr => supportedDataType(attr.dataType))
2808+
orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
2809+
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
28342810
if (!supported) {
28352811
msg = s"unsupported Spark partitioning expressions: $orderings"
28362812
}
@@ -2849,33 +2825,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
28492825
}
28502826

28512827
/**
2852-
* Whether the given Spark partitioning is supported by Comet.
2828+
* Whether the given Spark partitioning is supported by Comet native shuffle.
28532829
*/
2854-
def supportPartitioning(
2855-
inputs: Seq[Attribute],
2856-
partitioning: Partitioning): (Boolean, String) = {
2857-
def supportedDataType(dt: DataType): Boolean = dt match {
2858-
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
2859-
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
2860-
_: DateType | _: BooleanType =>
2861-
true
2862-
case _ =>
2863-
// Native shuffle doesn't support struct/array yet
2864-
false
2865-
}
2866-
2830+
def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
2831+
val inputs = s.child.output
2832+
val partitioning = s.outputPartitioning
28672833
var msg = ""
28682834
val supported = partitioning match {
28692835
case HashPartitioning(expressions, _) =>
28702836
val supported =
28712837
expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
2872-
expressions.forall(e => supportedDataType(e.dataType)) &&
2873-
inputs.forall(attr => supportedDataType(attr.dataType))
2838+
expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
2839+
inputs.forall(attr => supportedShuffleDataType(attr.dataType))
28742840
if (!supported) {
28752841
msg = s"unsupported Spark partitioning expressions: $expressions"
28762842
}
28772843
supported
2878-
case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
2844+
case SinglePartition => inputs.forall(attr => supportedShuffleDataType(attr.dataType))
28792845
case _ =>
28802846
msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
28812847
false
@@ -2889,6 +2855,31 @@ object QueryPlanSerde extends Logging with CometExprShim {
28892855
}
28902856
}
28912857

2858+
def supportedShuffleDataType(dt: DataType): Boolean = dt match {
2859+
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
2860+
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
2861+
_: DecimalType | _: DateType | _: BooleanType =>
2862+
true
2863+
case StructType(fields) =>
2864+
fields.forall(f => supportedShuffleDataType(f.dataType)) &&
2865+
// Java Arrow stream reader cannot work on duplicate field name
2866+
fields.map(f => f.name).distinct.length == fields.length
2867+
case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
2868+
case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
2869+
case ArrayType(elementType, _) =>
2870+
supportedShuffleDataType(elementType)
2871+
case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
2872+
case MapType(_, MapType(_, _, _), _) => false
2873+
case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
2874+
case MapType(_, StructType(_), _) => false
2875+
case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
2876+
case MapType(_, ArrayType(_, _), _) => false
2877+
case MapType(keyType, valueType, _) =>
2878+
supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
2879+
case _ =>
2880+
false
2881+
}
2882+
28922883
// Utility method. Adds explain info if the result of calling exprToProto is None
28932884
def optExprWithInfo(
28942885
optExpr: Option[Expr],

spark/src/main/scala/org/apache/comet/serde/hash.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3
2323
import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}
2424

2525
import org.apache.comet.CometSparkSessionExtensions.withInfo
26-
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarExprToProtoWithReturnType, serializeDataType, supportedDataType}
26+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarExprToProtoWithReturnType, serializeDataType, supportedShuffleDataType}
2727

2828
object CometXxHash64 extends CometExpressionSerde {
2929
override def convert(
@@ -74,7 +74,7 @@ private object HashUtils {
7474
// Java BigDecimal before hashing
7575
withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
7676
return false
77-
case dt if !supportedDataType(dt) =>
77+
case dt if !supportedShuffleDataType(dt) =>
7878
withInfo(expr, s"Unsupported datatype $dt")
7979
return false
8080
case _ =>

0 commit comments

Comments
 (0)