feat: Add support for complex types in native shuffle #1655
@@ -2785,52 +2785,31 @@ object QueryPlanSerde extends Logging with CometExprShim {
    * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
    * which supports struct/array.
    */
-  def supportPartitioningTypes(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
-        true
-      case StructType(fields) =>
-        fields.forall(f => supportedDataType(f.dataType)) &&
-          // Java Arrow stream reader cannot work on duplicate field name
-          fields.map(f => f.name).distinct.length == fields.length
-      case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
-      case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
-      case ArrayType(elementType, _) =>
-        supportedDataType(elementType)
-      case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
-      case MapType(_, MapType(_, _, _), _) => false
-      case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
-      case MapType(_, StructType(_), _) => false
-      case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
-      case MapType(_, ArrayType(_, _), _) => false
-      case MapType(keyType, valueType, _) =>
-        supportedDataType(keyType) && supportedDataType(valueType)
-      case _ =>
-        false
-    }
-
+  def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
-      case RoundRobinPartitioning(_) => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case RangePartitioning(orderings, _) =>
         val supported =
           orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $orderings"
         }
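For readers less familiar with Catalyst internals, the four partitioning schemes this method matches on can be constructed directly. A minimal sketch (class names taken from the diff above; nothing here is part of the PR itself):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

object PartitioningSketch extends App {
  val c0 = AttributeReference("c0", IntegerType)()

  // The four Partitioning variants that columnarShuffleSupported matches on:
  val hash = HashPartitioning(Seq(c0), 8)                          // hash of key expressions
  val robin = RoundRobinPartitioning(8)                            // no key; rows dealt out evenly
  val single = SinglePartition                                     // one output partition
  val range = RangePartitioning(Seq(SortOrder(c0, Ascending)), 8)  // sorted ranges (global sort)
}
```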
@@ -2849,33 +2828,42 @@ object QueryPlanSerde extends Logging with CometExprShim {
   }

   /**
-   * Whether the given Spark partitioning is supported by Comet.
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
    */
-  def supportPartitioning(
-      inputs: Seq[Attribute],
-      partitioning: Partitioning): (Boolean, String) = {
-    def supportedDataType(dt: DataType): Boolean = dt match {
-      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType |
-          _: DateType | _: BooleanType =>
+  def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * Hash Partition Key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
         true
       case _ =>
-        // Native shuffle doesn't support struct/array yet
         false
     }

+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
     var msg = ""
     val supported = partitioning match {
       case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType)) &&
-            inputs.forall(attr => supportedDataType(attr.dataType))
+            expressions.forall(e => supportedPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
         supported
-      case SinglePartition => inputs.forall(attr => supportedDataType(attr.dataType))
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
       case _ =>
         msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
         false
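To make the partition-key restriction concrete, here is a hedged sketch at the DataFrame level (column names invented for illustration): hash-repartitioning by a primitive column is eligible for native shuffle, while a complex column can only ride along as a value, not serve as the key.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object PartitionKeySketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("keys").getOrCreate()
  import spark.implicits._

  val df = Seq((1, Seq(1, 2)), (2, Seq(3))).toDF("id", "xs")

  // Primitive key: supportedPartitionKeyDataType(IntegerType) is true, and the
  // array column only needs to pass supportedShuffleDataType as a value column.
  df.repartition(4, col("id")).count()

  // Complex key: supportedPartitionKeyDataType(ArrayType(IntegerType)) is false,
  // so this exchange cannot use native shuffle (columnar shuffle may still apply).
  df.repartition(4, col("xs")).count()

  spark.stop()
}
```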
@@ -2889,6 +2877,34 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }

+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+        // Java Arrow stream reader cannot work on duplicate field name
+        fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
   // Utility method. Adds explain info if the result of calling exprToProto is None
   def optExprWithInfo(
       optExpr: Option[Expr],
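As a quick illustration of the cases above, this hedged sketch shows the answers the new method should give for a few representative types (the import path is a guess; only the pattern-match behavior is taken from the diff):

```scala
import org.apache.spark.sql.types._
// Assumes QueryPlanSerde is on the classpath; the exact package is an assumption.
import org.apache.comet.serde.QueryPlanSerde.supportedShuffleDataType

object ShuffleTypeSketch extends App {
  // Struct of primitives: supported.
  assert(supportedShuffleDataType(
    StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))))

  // Duplicate field names: rejected (Java Arrow stream reader limitation).
  assert(!supportedShuffleDataType(
    StructType(Seq(StructField("a", IntegerType), StructField("a", LongType)))))

  // Arrays: supported for primitive elements, not yet for nested arrays.
  assert(supportedShuffleDataType(ArrayType(IntegerType)))
  assert(!supportedShuffleDataType(ArrayType(ArrayType(IntegerType))))

  // Maps: supported for primitive key/value, not yet for struct values.
  assert(supportedShuffleDataType(MapType(StringType, LongType)))
  assert(!supportedShuffleDataType(MapType(StringType, StructType(Nil))))
}
```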
@@ -2920,7 +2936,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     val canSort = sortOrder.head.dataType match {
       case _: BooleanType => true
       case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-          _: DoubleType | _: TimestampType | _: TimestampType | _: DecimalType | _: DateType =>
+          _: DoubleType | _: TimestampType | _: TimestampNTZType | _: DecimalType |
+          _: DateType =>
         true
       case _: BinaryType | _: StringType => true
       case ArrayType(elementType, _) => canRank(elementType)

Review comment on lines -2923 to +2939:

> This is unrelated to the goal of the PR, but I noticed we had TimestampType twice and no TimestampNTZType.
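A small self-contained sketch of the typo the reviewer spotted: with TimestampType listed twice, TimestampNTZType values fell through to a later case, while the fixed pattern accepts them (the two functions below are invented for illustration; only the type patterns come from the diff):

```scala
import org.apache.spark.sql.types._

object TimestampPatternSketch extends App {
  // Old pattern: the second TimestampType alternative was redundant.
  def canSortOld(dt: DataType): Boolean = dt match {
    case _: TimestampType | _: TimestampType => true
    case _ => false
  }

  // Fixed pattern: covers both timestamp flavours.
  def canSortFixed(dt: DataType): Boolean = dt match {
    case _: TimestampType | _: TimestampNTZType => true
    case _ => false
  }

  assert(!canSortOld(TimestampNTZType)) // NTZ was not matched before the fix
  assert(canSortFixed(TimestampNTZType))
  assert(canSortFixed(TimestampType))
}
```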
@@ -30,6 +30,7 @@ import org.scalatest.Tag
 import org.apache.commons.io.FileUtils
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -161,6 +162,18 @@ class CometFuzzTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("shuffle") {
+    val df = spark.read.parquet(filename)
+    val df2 = df.repartition(8, df.col("c0")).sort("c1")
+    df2.collect()
+    if (CometConf.isExperimentalNativeScan) {
+      val cometShuffles = collect(df2.queryExecution.executedPlan) {
+        case exec: CometShuffleExchangeExec => exec
+      }
+      assert(1 == cometShuffles.length)
+    }
+  }
+
   test("join") {
     val df = spark.read.parquet(filename)
     df.createOrReplaceTempView("t1")

Review comment:

> Does the data have complex type?

Reply:

> Yes, the data has arrays and structs but not maps yet.
Review comment:

> Yes, it does! This method is removed and we now have a single supportedShuffleDataType method that is used for both native and columnar shuffle type checks.

Reply:

> Thanks for that, I was so confused about having this supported check in at least 3 places.