
Commit dc1b96c

prototype cost-based optimizer
1 parent 2aa20f0 commit dc1b96c

7 files changed, +260 -5 lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 8 additions & 0 deletions

@@ -401,6 +401,14 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_CBO_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.cbo.enabled")
+      .doc(
+        "Cost-based optimizer to avoid performance regressions where Comet plan may " +
+          "be slower than Spark plan.")
+      .booleanConf
+      .createWithDefault(false)
+
 }
 
 object ConfigHelpers {
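
For reference, the new flag follows the standard Comet config pattern and can be toggled per session like any other conf. A minimal usage sketch (assumes a running SparkSession with Comet on the classpath; not part of this commit):

    // Enable the prototype cost-based optimizer for the current session.
    // Comet itself must also be enabled (spark.comet.enabled=true).
    spark.conf.set("spark.comet.cbo.enabled", "true")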

docs/source/user-guide/configs.md

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ Comet provides the following configuration settings.
 |--------|-------------|---------------|
 | spark.comet.batchSize | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 |
 | spark.comet.cast.allowIncompatible | Comet is not currently fully compatible with Spark for all cast operations. Set this config to true to allow them anyway. See compatibility guide for more information. | false |
+| spark.comet.cbo.enabled | Cost-based optimizer to avoid performance regressions where Comet plan may be slower than Spark plan. | false |
 | spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. By default, this config is false. | false |
 | spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. By default, this config is 100. This is the upper bound of total number of shuffle threads per executor. In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config, Comet will use this config as the number of shuffle threads per executor instead. | 100 |
 | spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. By default, this config is 3. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 |

CometCostEvaluator.scala (new file)

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.comet.{CometExec, CometPlan, CometRowToColumnarExec, CometSinkPlaceHolder}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, RowToColumnarExec, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, Cost, CostEvaluator, QueryStageExec, SimpleCost}
+
+/**
+ * The goal of this cost model is to avoid introducing performance regressions in query stages
+ * during AQE.
+ *
+ * This evaluator will be called twice: once for the original Spark plan and once for the Comet
+ * plan. Spark will choose the cheaper plan.
+ */
+class CometCostEvaluator extends CostEvaluator with Logging {
+
+  /** Baseline cost for a Spark operator is 1.0 */
+  val DEFAULT_SPARK_OPERATOR_COST = 1.0
+
+  /** Relative cost of a Comet operator */
+  val DEFAULT_COMET_OPERATOR_COST = 0.8
+
+  /** Relative cost of a transition (C2R, R2C) */
+  val DEFAULT_TRANSITION_COST = 1.0
+
+  override def evaluateCost(plan: SparkPlan): Cost = {
+
+    // TODO this is a crude prototype where we just penalize transitions, but
+    // it can evolve into a true cost model where we have real numbers for the relative
+    // performance of Comet operators & expressions versus the Spark versions
+    //
+    // Some areas to explore:
+    // - can we use statistics from previous query stage(s)?
+    // - a transition after a filter should be cheaper than one before it (e.g. when reading
+    //   from Parquet followed by a filter, Comet will filter first and then transition)
+    def computePlanCost(plan: SparkPlan): Double = {
+
+      // get children even for leaf nodes at query stage edges
+      def getChildren(plan: SparkPlan) = plan match {
+        case a: AdaptiveSparkPlanExec => Seq(a.inputPlan)
+        case qs: QueryStageExec => Seq(qs.plan)
+        case p => p.children
+      }
+
+      val children = getChildren(plan)
+      val childPlanCost = children.map(computePlanCost).sum
+      val operatorCost = plan match {
+        case _: AdaptiveSparkPlanExec => 0
+        case _: CometSinkPlaceHolder => 0
+        case _: InputAdapter => 0
+        case _: WholeStageCodegenExec => 0
+        case RowToColumnarExec(_) => DEFAULT_TRANSITION_COST
+        case ColumnarToRowExec(_) => DEFAULT_TRANSITION_COST
+        case CometRowToColumnarExec(_) => DEFAULT_TRANSITION_COST
+        case _: CometExec => DEFAULT_COMET_OPERATOR_COST
+        case _ => DEFAULT_SPARK_OPERATOR_COST
+      }
+
+      def isSparkNative(plan: SparkPlan): Boolean = plan match {
+        case p: AdaptiveSparkPlanExec => isSparkNative(p.inputPlan)
+        case p: QueryStageExec => isSparkNative(p.plan)
+        case _: CometPlan => false
+        case _ => true
+      }
+
+      def isCometNative(plan: SparkPlan): Boolean = plan match {
+        case p: AdaptiveSparkPlanExec => isCometNative(p.inputPlan)
+        case p: QueryStageExec => isCometNative(p.plan)
+        case _: CometPlan => true
+        case _ => false
+      }
+
+      def isTransition(plan1: SparkPlan, plan2: SparkPlan) = {
+        (isSparkNative(plan1) && isCometNative(plan2)) ||
+          (isCometNative(plan1) && isSparkNative(plan2))
+      }
+
+      val transitionCost = if (children.exists(ch => isTransition(plan, ch))) {
+        DEFAULT_TRANSITION_COST
+      } else {
+        0
+      }
+
+      val totalCost = operatorCost + transitionCost + childPlanCost
+
+      logWarning(s"total cost is $totalCost ($operatorCost + $transitionCost + $childPlanCost) " +
+        s"for ${plan.nodeName}")
+
+      totalCost
+    }
+
+    // TODO can we access statistics from previous query stages?
+    val estimatedRowCount = 1000
+    val cost = (computePlanCost(plan) * estimatedRowCount).toLong
+
+    logWarning(s"Computed cost of $cost for $plan")
+
+    SimpleCost(cost)
+  }
+}
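
To make the cost model concrete, here is a hand-worked sketch under the constants above. The stage shape is hypothetical (a Comet-native scan and filter, both matching the `_: CometExec` case, feeding a ColumnarToRowExec); none of this appears in the commit itself:

    // Bottom-up: Comet scan -> Comet filter -> ColumnarToRowExec
    val cometOp = 0.8     // DEFAULT_COMET_OPERATOR_COST
    val transition = 1.0  // DEFAULT_TRANSITION_COST

    val scanCost = cometOp               // leaf Comet operator, no transition
    val filterCost = cometOp + scanCost  // Comet child, so no transition penalty
    // ColumnarToRowExec costs DEFAULT_TRANSITION_COST itself and also triggers the
    // transition penalty because its child is Comet-native:
    val c2rCost = transition + transition + filterCost  // 3.6
    // evaluateCost then scales by the fixed row-count estimate:
    val cost = (c2rCost * 1000).toLong                  // 3600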

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 61 additions & 4 deletions

@@ -32,18 +32,19 @@ import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf._
+import org.apache.comet.CometExplainInfo.CANNOT_RUN_NATIVE
 import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, isSpark40Plus, shouldApplyRowToColumnar, withInfo, withInfos}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator

@@ -65,15 +66,15 @@ class CometSparkSessionExtensions
     extensions.injectColumnar { session => CometScanColumnar(session) }
     extensions.injectColumnar { session => CometExecColumnar(session) }
     extensions.injectQueryStagePrepRule { session => CometScanRule(session) }
-    extensions.injectQueryStagePrepRule { session => CometExecRule(session) }
+    extensions.injectQueryStagePrepRule { session => CometQueryStagePrepRule(session) }
   }
 
   case class CometScanColumnar(session: SparkSession) extends ColumnarRule {
     override def preColumnarTransitions: Rule[SparkPlan] = CometScanRule(session)
   }
 
   case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
-    override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)
+    override def preColumnarTransitions: Rule[SparkPlan] = CometPreColumnarRule(session)
 
     override def postColumnarTransitions: Rule[SparkPlan] =
       EliminateRedundantTransitions(session)

@@ -192,6 +193,57 @@ class CometSparkSessionExtensions
     }
   }
 
+  case class CometQueryStagePrepRule(session: SparkSession) extends Rule[SparkPlan] {
+    def apply(plan: SparkPlan): SparkPlan = {
+
+      val newPlan = CometExecRule(session).apply(plan)
+
+      if (CometConf.COMET_CBO_ENABLED.get()) {
+        val costEvaluator = new CometCostEvaluator()
+        println(plan)
+        println(newPlan)
+        val sparkCost = costEvaluator.evaluateCost(plan)
+        val cometCost = costEvaluator.evaluateCost(newPlan)
+        println(s"sparkCost = $sparkCost, cometCost = $cometCost")
+        if (cometCost > sparkCost) {
+          val msg = s"Comet plan is more expensive than Spark plan ($cometCost > $sparkCost)" +
+            s"\nSPARK: $plan\n" +
+            s"\nCOMET: $newPlan\n"
+          logWarning(msg)
+          println(msg)
+          println(s"CometQueryStagePrepRule:\nIN: ${plan.getClass}\nOUT: ${plan.getClass}")
+
+          def fallbackRecursively(plan: SparkPlan): Unit = {
+            plan.setTagValue(CANNOT_RUN_NATIVE, true)
+            plan match {
+              case a: AdaptiveSparkPlanExec => fallbackRecursively(a.inputPlan)
+              case qs: QueryStageExec => fallbackRecursively(qs.plan)
+              case p => p.children.foreach(fallbackRecursively)
+            }
+          }
+          fallbackRecursively(plan)
+
+          return plan
+        }
+      }
+
+      println(s"CometQueryStagePrepRule:\nIN: ${plan.getClass}\nOUT: ${newPlan.getClass}")
+
+      newPlan
+    }
+  }
+
+  case class CometPreColumnarRule(session: SparkSession) extends Rule[SparkPlan] {
+    def apply(plan: SparkPlan): SparkPlan = {
+      val newPlan = CometExecRule(session).apply(plan)
+      println(s"CometPreColumnarRule:\nIN: ${plan.getClass}\nOUT: ${newPlan.getClass}")
+      newPlan
+    }
+  }
+
   case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
       plan.transformUp {

@@ -727,6 +779,11 @@ class CometSparkSessionExtensions
       // We shouldn't transform Spark query plan if Comet is disabled.
       if (!isCometEnabled(conf)) return plan
 
+      if (plan.getTagValue(CANNOT_RUN_NATIVE).getOrElse(false)) {
+        println("Cannot run native - too slow")
+        return plan
+      }
+
       if (!isCometExecEnabled(conf)) {
         // Comet exec is disabled, but for Spark shuffle, we still can use Comet columnar shuffle
         if (isCometShuffleEnabled(conf)) {
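
The fallback path above hinges on Spark's TreeNodeTag API: CometQueryStagePrepRule tags every node of the original plan with CANNOT_RUN_NATIVE, and when CometExecRule runs again later it sees the tag and returns the plan untouched. A minimal standalone sketch of that round-trip (for illustration only; names mirror the commit):

    import org.apache.spark.sql.catalyst.trees.TreeNodeTag
    import org.apache.spark.sql.execution.SparkPlan

    val cannotRunNative = new TreeNodeTag[Boolean]("CometCannotRunNative")

    // Mark a plan (and its children) as "keep on Spark".
    def markFallback(plan: SparkPlan): Unit = {
      plan.setTagValue(cannotRunNative, true)
      plan.children.foreach(markFallback)
    }

    // Later rule invocations consult the tag before transforming.
    def shouldFallBack(plan: SparkPlan): Boolean =
      plan.getTagValue(cannotRunNative).getOrElse(false)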

spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala

Lines changed: 1 addition & 0 deletions

@@ -85,4 +85,5 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
 
 object CometExplainInfo {
   val EXTENSION_INFO = new TreeNodeTag[Set[String]]("CometExtensionInfo")
+  val CANNOT_RUN_NATIVE = new TreeNodeTag[Boolean]("CometCannotRunNative")
 }

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 1 deletion

@@ -2301,7 +2301,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         .addAllSortOrders(sortOrders.map(_.get).asJava)
       Some(result.setSort(sortBuilder).build())
     } else {
-      withInfo(op, sortOrder: _*)
+      withInfo(op, "sort not allowed", sortOrder: _*)
       None
     }

CostBasedOptimizerSuite.scala (new file)

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+
+class CostBasedOptimizerSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  private val dataGen = DataGenerator.DEFAULT
+
+  test("tbd") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false",
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_CBO_ENABLED.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") {
+      val table = "t1"
+      withTable(table, "t2") {
+        sql(s"create table t1(col string, a int, b float) using parquet")
+        sql(s"create table t2(col string, a int, b float) using parquet")
+        val tableSchema = spark.table(table).schema
+        val rows = dataGen.generateRows(
+          1000,
+          tableSchema,
+          Some(() => dataGen.generateString("tbd:", 6)))
+        val data = spark.createDataFrame(spark.sparkContext.parallelize(rows), tableSchema)
+        data.write
+          .mode("append")
+          .insertInto(table)
+        data.write
+          .mode("append")
+          .insertInto("t2")
+        val x = checkSparkAnswer/*AndOperator*/("select t1.col as x " +
+          "from t1 join t2 on cast(t1.col as timestamp) = cast(t2.col as timestamp) " +
+          "order by x")
+
+        // TODO assert that we fell back for the whole plan
+        println(x._1)
+        println(x._2)
+
+        assert(!x._2.toString().contains("CometSortExec"))
+      }
+    }
+  }
+}
