diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java index cc78d1d68e47..964fca5aa80c 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CompactProcedure.java @@ -56,6 +56,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.PaimonUtils; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; @@ -66,6 +67,8 @@ import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.annotation.Nullable; @@ -97,6 +100,8 @@ */ public class CompactProcedure extends BaseProcedure { + private static final Logger LOG = LoggerFactory.getLogger(CompactProcedure.class); + private static final ProcedureParameter[] PARAMETERS = new ProcedureParameter[] { ProcedureParameter.required("table", StringType), @@ -182,7 +187,6 @@ public InternalRow[] call(InternalRow args) { dynamicOptions.putAll(ParameterUtils.parseCommaSeparatedKeyValues(options)); } table = table.copy(dynamicOptions); - InternalRow internalRow = newInternalRow( execute( @@ -279,10 +283,11 @@ private void compactAwareBucketTable( return; } + int readParallelism = readParallelism(partitionBuckets, spark()); BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder(); JavaRDD commitMessageJavaRDD = javaSparkContext - .parallelize(partitionBuckets) + .parallelize(partitionBuckets, readParallelism) .mapPartitions( (FlatMapFunction>, byte[]>) pairIterator -> { @@ -355,6 +360,7 @@ private void compactUnAwareBucketTable( .collect(Collectors.toList()); } if (compactionTasks.isEmpty()) { + System.out.println("compaction task is empty."); return; } @@ -368,10 +374,11 @@ private void compactUnAwareBucketTable( throw new RuntimeException("serialize compaction task failed"); } + int readParallelism = readParallelism(serializedTasks, spark()); String commitUser = createCommitUser(table.coreOptions().toConfiguration()); JavaRDD commitMessageJavaRDD = javaSparkContext - .parallelize(serializedTasks) + .parallelize(serializedTasks, readParallelism) .mapPartitions( (FlatMapFunction, byte[]>) taskIterator -> { @@ -485,6 +492,22 @@ private Map packForSort(List dataSplits) { list -> list.toArray(new DataSplit[0])))); } + private int readParallelism(List groupedTasks, SparkSession spark) { + int sparkParallelism = + Math.max( + spark.sparkContext().defaultParallelism(), + spark.sessionState().conf().numShufflePartitions()); + int readParallelism = Math.min(groupedTasks.size(), sparkParallelism); + if (sparkParallelism > readParallelism) { + LOG.warn( + String.format( + "Spark default parallelism (%s) is greater than bucket or task parallelism (%s)," + + "we use %s as the final read parallelism", + sparkParallelism, readParallelism, readParallelism)); + } + return readParallelism; + } + @VisibleForTesting static String toWhere(String partitions) { List> maps = ParameterUtils.getPartitions(partitions.split(";")); diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala index d3d77ccef41c..130860c8351e 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/procedure/CompactProcedureTestBase.scala @@ -24,6 +24,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase import org.apache.paimon.table.FileStoreTable import org.apache.paimon.table.source.DataSplit +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted, SparkListenerStageSubmitted} import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming.StreamTest @@ -648,6 +649,87 @@ abstract class CompactProcedureTestBase extends PaimonSparkTestBase with StreamT } } + test("Paimon Procedure: test aware-bucket compaction read parallelism") { + spark.sql(s""" + |CREATE TABLE T (id INT, value STRING) + |TBLPROPERTIES ('primary-key'='id', 'bucket'='3', 'write-only'='true') + |""".stripMargin) + + val table = loadTable("T") + for (i <- 1 to 10) { + sql(s"INSERT INTO T VALUES ($i, '$i')") + } + assertResult(10)(table.snapshotManager().snapshotCount()) + + val buckets = table.newSnapshotReader().bucketEntries().asScala.map(_.bucket()).distinct.size + assertResult(3)(buckets) + + val taskBuffer = scala.collection.mutable.ListBuffer.empty[Int] + val listener = new SparkListener { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + taskBuffer += stageSubmitted.stageInfo.numTasks + } + } + + try { + spark.sparkContext.addSparkListener(listener) + + // spark.default.parallelism cannot be change in spark session + // sparkParallelism is 2, bucket is 3, use 2 as the read parallelism + spark.conf.set("spark.sql.shuffle.partitions", 2) + spark.sql("CALL sys.compact(table => 'T')") + + // sparkParallelism is 5, bucket is 3, use 3 as the read parallelism + spark.conf.set("spark.sql.shuffle.partitions", 5) + spark.sql("CALL sys.compact(table => 'T')") + + assertResult(Seq(2, 3))(taskBuffer) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + + test("Paimon Procedure: test unaware-bucket compaction read parallelism") { + spark.sql(s""" + |CREATE TABLE T (id INT, value STRING) + |TBLPROPERTIES ('bucket'='-1', 'write-only'='true') + |""".stripMargin) + + val table = loadTable("T") + for (i <- 1 to 12) { + sql(s"INSERT INTO T VALUES ($i, '$i')") + } + assertResult(12)(table.snapshotManager().snapshotCount()) + + val buckets = table.newSnapshotReader().bucketEntries().asScala.map(_.bucket()).distinct.size + // only has bucket-0 + assertResult(1)(buckets) + + val taskBuffer = scala.collection.mutable.ListBuffer.empty[Int] + val listener = new SparkListener { + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + taskBuffer += stageSubmitted.stageInfo.numTasks + } + } + + try { + spark.sparkContext.addSparkListener(listener) + + // spark.default.parallelism cannot be change in spark session + // sparkParallelism is 2, task groups is 6, use 2 as the read parallelism + spark.conf.set("spark.sql.shuffle.partitions", 2) + spark.sql("CALL sys.compact(table => 'T', options => 'compaction.max.file-num=2')") + + // sparkParallelism is 5, task groups is 3, use 3 as the read parallelism + spark.conf.set("spark.sql.shuffle.partitions", 5) + spark.sql("CALL sys.compact(table => 'T', options => 'compaction.max.file-num=2')") + + assertResult(Seq(2, 3))(taskBuffer) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + def lastSnapshotCommand(table: FileStoreTable): CommitKind = { table.snapshotManager().latestSnapshot().commitKind() }