From ba7cadc8ba564058802325d671d919920206709d Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Tue, 5 Nov 2024 15:16:41 +0800 Subject: [PATCH 1/4] Fix check write test --- .../sql/KyuubiSparkSQLExtensionTest.scala | 39 +++++++++-- .../sql/RebalanceBeforeWritingSuite.scala | 67 +++++++++---------- 2 files changed, 66 insertions(+), 40 deletions(-) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala index 996bef763a2..11704a26d40 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql +import java.util.Collections + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.SparkConf -import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions} +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, LogicalQueryStage} import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SQLTestData.TestData @@ -99,15 +102,24 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest withListener(sql(sqlString))(callback) } - def withListener(df: => DataFrame)(callback: DataWritingCommand => Unit): Unit = { + def withListener(df: => DataFrame)( + callback: DataWritingCommand => Unit, + failIfNotCallback: Boolean = true): Unit = { + val writes = Collections.synchronizedList(new java.util.ArrayList[DataWritingCommand]()) + val listener = new QueryExecutionListener { override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - qe.executedPlan match { - case write: DataWritingCommandExec => callback(write.cmd) - case _ => + def doCallback(plan: SparkPlan): Unit = { + plan match { + case write: DataWritingCommandExec => + writes.add(write.cmd) + case a: AdaptiveSparkPlanExec => doCallback(a.executedPlan) + case _ => + } } + doCallback(qe.executedPlan) } } spark.listenerManager.register(listener) @@ -117,5 +129,20 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest } finally { spark.listenerManager.unregister(listener) } + if (failIfNotCallback && writes.isEmpty) { + fail("No write command found") + } + writes.forEach(callback(_)) + } + + def collectRebalancePartitions(plan: LogicalPlan): Seq[RebalancePartitions] = { + def collect(p: LogicalPlan): Seq[RebalancePartitions] = { + p.flatMap { + case r: RebalancePartitions => Seq(r) + case s: LogicalQueryStage => collect(s.logicalPlan) + case _ => Nil + } + } + collect(plan) } } diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala index 46ba272011b..2eb970573f0 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartit import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.kyuubi.sql.KyuubiSQLConf @@ -31,17 +30,15 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { test("check rebalance exists") { def check(df: => DataFrame, expectedRebalanceNum: Int = 1): Unit = { withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { + var rebalanceNum = 0 withListener(df) { write => - assert(write.collect { - case r: RebalancePartitions => r - }.size == expectedRebalanceNum) + rebalanceNum += collectRebalancePartitions(write).size } + assert(rebalanceNum == expectedRebalanceNum) } withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "false") { withListener(df) { write => - assert(write.collect { - case r: RebalancePartitions => r - }.isEmpty) + assert(collectRebalancePartitions(write).isEmpty) } } } @@ -97,11 +94,12 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { } test("check rebalance does not exists") { - def check(df: DataFrame): Unit = { + def checkQuery(df: => DataFrame): Unit = { + assert(collectRebalancePartitions(df.queryExecution.analyzed).isEmpty) + } + def checkWrite(df: => DataFrame): Unit = { withListener(df) { write => - assert(write.collect { - case r: RebalancePartitions => r - }.isEmpty) + assert(collectRebalancePartitions(write).isEmpty) } } @@ -109,17 +107,17 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true", KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { // test no write command - check(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) - check(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) + checkQuery(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) + checkQuery(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) // test not supported plan withTable("tmp1") { sql(s"CREATE TABLE tmp1 (c1 int) PARTITIONED BY (c2 string)") - check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + "SELECT /*+ repartition(10) */ * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) - check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) ORDER BY c1")) - check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) LIMIT 10")) } } @@ -128,13 +126,13 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { Seq("USING PARQUET", "").foreach { storage => withTable("tmp1") { sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") - check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + + checkWrite(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " + "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")) } withTable("tmp1") { sql(s"CREATE TABLE tmp1 (c1 int) $storage") - check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)")) + checkWrite(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)")) } } } @@ -143,12 +141,10 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { test("test dynamic partition write") { def checkRepartitionExpression(sqlString: String): Unit = { withListener(sqlString) { write => - assert(write.isInstanceOf[InsertIntoHiveTable]) - assert(write.collect { - case r: RebalancePartitions if r.partitionExpressions.size == 1 => - assert(r.partitionExpressions.head.asInstanceOf[Attribute].name === "c2") - r - }.size == 1) + val rebalancePartitions = collectRebalancePartitions(write) + assert(rebalancePartitions.size == 1) + assert(rebalancePartitions.head.partitionExpressions.size == 1 && + rebalancePartitions.head.partitionExpressions.head.asInstanceOf[Attribute].name === "c2") } } @@ -156,14 +152,19 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true", KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE_IF_NO_SHUFFLE.key -> "true") { Seq("USING PARQUET", "").foreach { storage => - withTable("tmp1") { - sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") - checkRepartitionExpression("INSERT INTO TABLE tmp1 SELECT 1 as c1, 'a' as c2 ") - } + withTable("tmp2") { + sql(s"CREATE TABLE tmp2 (c1 int, c2 string)") + sql(s"INSERT INTO tmp2 SELECT 1, 'a'") - withTable("tmp1") { - checkRepartitionExpression( - "CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT 1 as c1, 'a' as c2") + withTable("tmp1") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + checkRepartitionExpression("INSERT INTO TABLE tmp1 SELECT c1, c2 from tmp2") + } + + withTable("tmp1") { + checkRepartitionExpression( + "CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT c1, c2 from tmp2") + } } } } @@ -177,9 +178,7 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { withTable("t") { withListener("CREATE TABLE t STORED AS parquet AS SELECT 1 as a") { write => assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand]) - assert(write.collect { - case _: RebalancePartitions => true - }.size == 1) + assert(collectRebalancePartitions(write).size == 1) } } } From df4d79866863ecc3fe8ae1d34ab4df7ed42d4b5e Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Tue, 5 Nov 2024 16:08:48 +0800 Subject: [PATCH 2/4] fix name --- .../org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala index 11704a26d40..762d6bb1ebc 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala @@ -111,15 +111,15 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - def doCallback(plan: SparkPlan): Unit = { + def collectWrite(plan: SparkPlan): Unit = { plan match { case write: DataWritingCommandExec => writes.add(write.cmd) - case a: AdaptiveSparkPlanExec => doCallback(a.executedPlan) + case a: AdaptiveSparkPlanExec => collectWrite(a.executedPlan) case _ => } } - doCallback(qe.executedPlan) + collectWrite(qe.executedPlan) } } spark.listenerManager.register(listener) From adb0f9b037845a3fdf7ecd23359c04be0331f3ce Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Tue, 5 Nov 2024 18:09:13 +0800 Subject: [PATCH 3/4] fix zorder test --- .../apache/spark/sql/ZorderSuiteBase.scala | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala index b8e01fa836b..c50e94f43f4 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/ZorderSuiteBase.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFu import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, EqualTo, Expression, ExpressionEvalHelper, Literal, NullsLast, SortOrder} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Project, Sort} -import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand +import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFiles} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} @@ -245,7 +245,15 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel planHasRepartition: Boolean, resHasSort: Boolean): Unit = { def checkSort(plan: LogicalPlan): Unit = { - assert(plan.isInstanceOf[Sort] === resHasSort) + def collectSort(plan: LogicalPlan): Option[Sort] = { + plan match { + case sort: Sort => Some(sort) + case f: WriteFiles => collectSort(f.child) + case _ => None + } + } + val sortOpt = collectSort(plan) + assert(sortOpt.isDefined === resHasSort) plan match { case sort: Sort => val colArr = cols.split(",") @@ -332,19 +340,20 @@ trait ZorderSuiteBase extends KyuubiSparkSQLExtensionTest with ExpressionEvalHel assert(df1.queryExecution.analyzed.isInstanceOf[InsertIntoHadoopFsRelationCommand]) checkSort(df1.queryExecution.analyzed.children.head) - withListener( - s""" - |CREATE TABLE zorder_t4 USING PARQUET - |TBLPROPERTIES ( - | 'kyuubi.zorder.enabled' = '$enabled', - | 'kyuubi.zorder.cols' = '$cols') - | - |SELECT $repartition * FROM - |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4) - |""".stripMargin) { write => - assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand]) - checkSort(write.query) - } + // TODO: CreateDataSourceTableAsSelectCommand is not supported +// withListener( +// s""" +// |CREATE TABLE zorder_t4 USING PARQUET +// |TBLPROPERTIES ( +// | 'kyuubi.zorder.enabled' = '$enabled', +// | 'kyuubi.zorder.cols' = '$cols') +// | +// |SELECT $repartition * FROM +// |VALUES(1,'a',2,4D),(2,'b',3,6D) AS t(c1 ,c2 , c3, c4) +// |""".stripMargin) { write => +// assert(write.isInstanceOf[InsertIntoHadoopFsRelationCommand]) +// checkSort(write.query) +// } } } } From a989de58f38ee9c4c8c652650c0fb314dad0f399 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Wed, 6 Nov 2024 10:16:07 +0800 Subject: [PATCH 4/4] fix listener --- .../org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala index 762d6bb1ebc..72d05a7289d 100644 --- a/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala +++ b/extensions/spark/kyuubi-extension-spark-3-5/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala @@ -122,6 +122,8 @@ trait KyuubiSparkSQLExtensionTest extends QueryTest collectWrite(qe.executedPlan) } } + // Make sure the listener is registered after all previous events have been processed + sparkContext.listenerBus.waitUntilEmpty() spark.listenerManager.register(listener) try { df.collect()