@@ -558,7 +558,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     case c @ Cast(child, dt, timeZoneId, _) =>
       handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))

-    case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
+    case add @ Add(left, right, _) if supportedShuffleDataType(left.dataType) =>
       createMathExpression(
         expr,
         left,
@@ -569,11 +569,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
         add.evalMode == EvalMode.ANSI,
         (builder, mathExpr) => builder.setAdd(mathExpr))

-    case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
+    case add @ Add(left, _, _) if !supportedShuffleDataType(left.dataType) =>
       withInfo(add, s"Unsupported datatype ${left.dataType}")
       None

-    case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
+    case sub @ Subtract(left, right, _) if supportedShuffleDataType(left.dataType) =>
       createMathExpression(
         expr,
         left,
@@ -584,11 +584,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
         sub.evalMode == EvalMode.ANSI,
         (builder, mathExpr) => builder.setSubtract(mathExpr))

-    case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
+    case sub @ Subtract(left, _, _) if !supportedShuffleDataType(left.dataType) =>
       withInfo(sub, s"Unsupported datatype ${left.dataType}")
       None

-    case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) =>
+    case mul @ Multiply(left, right, _) if supportedShuffleDataType(left.dataType) =>
       createMathExpression(
         expr,
         left,
@@ -600,12 +600,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
         (builder, mathExpr) => builder.setMultiply(mathExpr))

     case mul @ Multiply(left, _, _) =>
-      if (!supportedDataType(left.dataType)) {
+      if (!supportedShuffleDataType(left.dataType)) {
         withInfo(mul, s"Unsupported datatype ${left.dataType}")
       }
       None

-    case div @ Divide(left, right, _) if supportedDataType(left.dataType) =>
+    case div @ Divide(left, right, _) if supportedShuffleDataType(left.dataType) =>
       // Datafusion now throws an exception for dividing by zero
       // See https://github.com/apache/arrow-datafusion/pull/6792
       // For now, use NullIf to swap zeros with nulls.
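The three comment lines above describe the divide-by-zero workaround: before the division is handed to the native engine, the divisor is rewritten so that zeros become nulls. A minimal Catalyst-level sketch of that idea follows; nullIfWhenPrimitive is the helper this file actually calls, while nullIfZero and its exact construction here are illustrative only.

    // Illustrative sketch, not the Comet helper itself: turn a zero divisor into
    // null so the native division never sees a literal zero.
    import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, If, Literal}

    def nullIfZero(divisor: Expression): Expression =
      If(
        EqualTo(divisor, Literal.default(divisor.dataType)),
        Literal.create(null, divisor.dataType),
        divisor)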
@@ -622,12 +622,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
         (builder, mathExpr) => builder.setDivide(mathExpr))

     case div @ Divide(left, _, _) =>
-      if (!supportedDataType(left.dataType)) {
+      if (!supportedShuffleDataType(left.dataType)) {
         withInfo(div, s"Unsupported datatype ${left.dataType}")
       }
       None

-    case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) =>
+    case div @ IntegralDivide(left, right, _) if supportedShuffleDataType(left.dataType) =>
       val rightExpr = nullIfWhenPrimitive(right)

       val dataType = (left.dataType, right.dataType) match {
@@ -671,12 +671,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
       }

     case div @ IntegralDivide(left, _, _) =>
-      if (!supportedDataType(left.dataType)) {
+      if (!supportedShuffleDataType(left.dataType)) {
         withInfo(div, s"Unsupported datatype ${left.dataType}")
       }
       None

-    case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) =>
+    case rem @ Remainder(left, right, _) if supportedShuffleDataType(left.dataType) =>
       val rightExpr = nullIfWhenPrimitive(right)

       createMathExpression(
@@ -690,7 +690,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
         (builder, mathExpr) => builder.setRemainder(mathExpr))

     case rem @ Remainder(left, _, _) =>
-      if (!supportedDataType(left.dataType)) {
+      if (!supportedShuffleDataType(left.dataType)) {
         withInfo(rem, s"Unsupported datatype ${left.dataType}")
       }
       None
@@ -816,7 +816,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
         withInfo(expr, s"Unsupported datatype $dataType")
         None
       }
-    case Literal(_, dataType) if !supportedDataType(dataType) =>
+    case Literal(_, dataType) if !supportedShuffleDataType(dataType) =>
       withInfo(expr, s"Unsupported datatype $dataType")
       None

@@ -1786,7 +1786,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
         ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
       }

-    case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) =>
+    case s @ execution.ScalarSubquery(_, _) if supportedShuffleDataType(s.dataType) =>
       val dataType = serializeDataType(s.dataType)
       if (dataType.isEmpty) {
         withInfo(s, s"Scalar subquery returns unsupported datatype ${s.dataType}")
@@ -2785,52 +2785,28 @@ object QueryPlanSerde extends Logging with CometExprShim {
    * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
    * which supports struct/array.
    */
-  def supportPartitioningTypes(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
-        true
-      case StructType(fields) =>
-        fields.forall(f => supportedDataType(f.dataType)) &&
-          // Java Arrow stream reader cannot work on duplicate field name
-          fields.map(f => f.name).distinct.length == fields.length
-      case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
-      case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
-      case ArrayType(elementType, _) =>
-        supportedDataType(elementType)
-      case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
-      case MapType(_, MapType(_, _, _), _) => false
-      case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
-      case MapType(_, StructType(_), _) => false
-      case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
-      case MapType(_, ArrayType(_, _), _) => false
-      case MapType(keyType, valueType, _) =>
-        supportedDataType(keyType) && supportedDataType(valueType)
-      case _ =>
-        false
-    }
-
+  def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
-      case RoundRobinPartitioning(_) => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition => inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case RangePartitioning(orderings, _) =>
         val supported =
           orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $orderings"
         }
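With this hunk the columnar-shuffle check takes the ShuffleExchangeExec directly instead of a separate attribute list and partitioning, so callers no longer unpack the plan themselves. A hedged sketch of a call site follows; apart from QueryPlanSerde.columnarShuffleSupported (the method renamed above), the surrounding names and fallback handling are illustrative.

    // Illustrative caller: `shuffle` is assumed to be a ShuffleExchangeExec in scope,
    // and the warning/fallback handling is a sketch, not Comet's actual rule.
    val (supported, reason) = QueryPlanSerde.columnarShuffleSupported(shuffle)
    if (!supported) {
      logWarning(s"Comet columnar shuffle not used: $reason")
    }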
@@ -2849,33 +2825,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
   }

   /**
-   * Whether the given Spark partitioning is supported by Comet.
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
    */
-  def supportPartitioning(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
-        true
-      case _ =>
-        // Native shuffle doesn't support struct/array yet
-        false
-    }
-
+  def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition => inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case _ =>
         msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
         false
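Both checks now share the same signature, which makes it easy for a planner rule to prefer native shuffle and fall back to columnar shuffle when the partitioning or input types rule it out. The selection below is only a sketch of that idea, not Comet's actual fallback logic.

    // Illustrative only: roughly how a rewrite rule might combine the two checks.
    def chooseShuffle(s: ShuffleExchangeExec): String = {
      val (nativeOk, nativeMsg) = QueryPlanSerde.nativeShuffleSupported(s)
      val (columnarOk, columnarMsg) = QueryPlanSerde.columnarShuffleSupported(s)
      if (nativeOk) "native"
      else if (columnarOk) "columnar"
      else s"spark ($nativeMsg; $columnarMsg)"
    }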
@@ -2889,6 +2855,31 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }

+  def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+        _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
+        _: DecimalType | _: DateType | _: BooleanType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+        // Java Arrow stream reader cannot work on duplicate field name
+        fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
   // Utility method. Adds explain info if the result of calling exprToProto is None
   def optExprWithInfo(
       optExpr: Option[Expr],
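The supportedShuffleDataType helper added in the hunk above accepts nested structs, arrays, and maps but rejects the combinations called out in its TODO comments. A few illustrative invocations, using the standard Spark SQL types API (examples only, not tests added by this change):

    import org.apache.spark.sql.types._

    supportedShuffleDataType(ArrayType(IntegerType))                      // true
    supportedShuffleDataType(ArrayType(ArrayType(IntegerType)))           // false: nested array
    supportedShuffleDataType(StructType(Seq(
      StructField("a", LongType), StructField("b", StringType))))         // true
    supportedShuffleDataType(MapType(StringType, ArrayType(IntegerType))) // false: array as map value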