Commit c04784a

feat: Add support for complex types in native shuffle (#1655)
1 parent bfcb968 commit c04784a

3 files changed (+86 −56 lines)
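
For context, a hypothetical example (not part of the commit) of the kind of query this change affects: with this commit, a repartition whose payload contains complex (struct/array) columns can take the Comet native shuffle path, as long as the hash-partition key itself is still a primitive type.

```scala
// Illustration only, assuming a Comet-enabled Spark session; column names are made up.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, struct}

val spark = SparkSession.builder().appName("comet-shuffle-demo").getOrCreate()

// Payload includes a struct column and an array column.
val df = spark
  .range(1000)
  .withColumn("s", struct(col("id").as("a"), (col("id") * 2).as("b")))
  .withColumn("arr", array(col("id"), col("id") + 1))

// Partition key "id" is primitive; the complex columns ride along through
// the shuffle, which previously forced a fallback to Spark's shuffle.
df.repartition(8, col("id")).count()
```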

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 9 additions & 9 deletions
```diff
@@ -240,7 +240,7 @@ class CometSparkSessionExtensions
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 =>
+            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")

         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -253,7 +253,7 @@ class CometSparkSessionExtensions
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(
             conf) &&
-            QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
+            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -719,7 +719,7 @@ class CometSparkSessionExtensions
       case s: ShuffleExchangeExec =>
         val nativePrecondition = isCometShuffleEnabled(conf) &&
           isCometNativeShuffleMode(conf) &&
-          QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1
+          QueryPlanSerde.nativeShuffleSupported(s)._1

         val nativeShuffle: Option[SparkPlan] =
           if (nativePrecondition) {
@@ -753,7 +753,7 @@ class CometSparkSessionExtensions
         // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
         // convert it to CometColumnarShuffle,
         if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-          QueryPlanSerde.supportPartitioningTypes(s.child.output, s.outputPartitioning)._1 &&
+          QueryPlanSerde.columnarShuffleSupported(s)._1 &&
           !isShuffleOperator(s.child)) {

           val newOp = QueryPlanSerde.operator2Proto(s)
@@ -781,22 +781,22 @@ class CometSparkSessionExtensions
           nativeOrColumnarShuffle.get
         } else {
           val isShuffleEnabled = isCometShuffleEnabled(conf)
-          val outputPartitioning = s.outputPartitioning
+          s.outputPartitioning
           val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available")
           val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
           val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
           val msg2 = createMessage(
             isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-              .supportPartitioning(s.child.output, outputPartitioning)
+              .nativeShuffleSupported(s)
               ._1,
             "Native shuffle: " +
-              s"${QueryPlanSerde.supportPartitioning(s.child.output, outputPartitioning)._2}")
+              s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
           val typeInfo = QueryPlanSerde
-            .supportPartitioningTypes(s.child.output, outputPartitioning)
+            .columnarShuffleSupported(s)
             ._2
           val msg3 = createMessage(
             isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-              .supportPartitioningTypes(s.child.output, outputPartitioning)
+              .columnarShuffleSupported(s)
               ._1,
             "JVM shuffle: " +
               s"$typeInfo")
```

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 64 additions & 47 deletions
```diff
@@ -2785,52 +2785,31 @@ object QueryPlanSerde extends Logging with CometExprShim {
    * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
    * which supports struct/array.
    */
-  def supportPartitioningTypes(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
-        true
-      case StructType(fields) =>
-        fields.forall(f => supportedDataType(f.dataType)) &&
-        // Java Arrow stream reader cannot work on duplicate field name
-        fields.map(f => f.name).distinct.length == fields.length
-      case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
-      case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
-      case ArrayType(elementType, _) =>
-        supportedDataType(elementType)
-      case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
-      case MapType(_, MapType(_, _, _), _) => false
-      case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
-      case MapType(_, StructType(_), _) => false
-      case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
-      case MapType(_, ArrayType(_, _), _) => false
-      case MapType(keyType, valueType, _) =>
-        supportedDataType(keyType) && supportedDataType(valueType)
-      case _ =>
-        false
-    }
-
+  def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
-      case RoundRobinPartitioning(_) => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case RangePartitioning(orderings, _) =>
         val supported =
           orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $orderings"
         }
@@ -2849,33 +2828,42 @@ object QueryPlanSerde extends Logging with CometExprShim {
   }

   /**
-   * Whether the given Spark partitioning is supported by Comet.
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
    */
-  def supportPartitioning(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
+  def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * Hash Partition Key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
         true
       case _ =>
-        // Native shuffle doesn't support struct/array yet
         false
     }

+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
        val supported =
          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
        if (!supported) {
          msg = s"unsupported Spark partitioning expressions: $expressions"
        }
        supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case _ =>
         msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
         false
@@ -2889,6 +2877,34 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }

+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+      // Java Arrow stream reader cannot work on duplicate field name
+      fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
   // Utility method. Adds explain info if the result of calling exprToProto is None
   def optExprWithInfo(
       optExpr: Option[Expr],
@@ -2920,7 +2936,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     val canSort = sortOrder.head.dataType match {
       case _: BooleanType => true
       case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: TimestampType | _: TimestampType | _: DecimalType | _: DateType =>
+          _: DoubleType | _: TimestampType | _: TimestampNTZType | _: DecimalType |
+          _: DateType =>
         true
       case _: BinaryType | _: StringType => true
       case ArrayType(elementType, _) => canRank(elementType)
```
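
The new shared `supportedShuffleDataType` check recurses through nested types and rejects the combinations flagged as TODO above. A sketch (illustration only, assuming a Spark 3.x classpath) of how it treats a few representative schemas:

```scala
import org.apache.spark.sql.types._

// Flat struct of primitives: supported.
val flatStruct = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))

// Struct with duplicate field names: rejected, because the Java Arrow
// stream reader cannot work on duplicate field names.
val dupStruct = StructType(Seq(StructField("a", IntegerType), StructField("a", LongType)))

// Nested array: rejected (TODO in the diff).
val nestedArray = ArrayType(ArrayType(IntegerType))

// Map with primitive key/value: supported; map with struct value: rejected.
val primMap = MapType(StringType, LongType)
val structMap = MapType(StringType, flatStruct)

// With QueryPlanSerde in scope one would expect:
//   supportedShuffleDataType(flatStruct)  == true
//   supportedShuffleDataType(dupStruct)   == false
//   supportedShuffleDataType(nestedArray) == false
//   supportedShuffleDataType(primMap)     == true
//   supportedShuffleDataType(structMap)   == false
```

Note the asymmetry in `nativeShuffleSupported`: payload columns go through `supportedShuffleDataType` (complex types allowed), while hash-partition keys go through the stricter `supportedPartitionKeyDataType`, since native shuffle lacks hashing support for complex types.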

spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@ import org.scalatest.Tag
 import org.apache.commons.io.FileUtils
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -161,6 +162,18 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("shuffle") {
+    val df = spark.read.parquet(filename)
+    val df2 = df.repartition(8, df.col("c0")).sort("c1")
+    df2.collect()
+    if (CometConf.isExperimentalNativeScan) {
+      val cometShuffles = collect(df2.queryExecution.executedPlan) {
+        case exec: CometShuffleExchangeExec => exec
+      }
+      assert(1 == cometShuffles.length)
+    }
+  }
+
   test("join") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")
```
