Commit
[spark] Compaction add parallelize parallelism to avoid small partitions (#4158)
askwang committed Sep 13, 2024
1 parent 588d7f2 commit 2c45ac0
Showing 2 changed files with 108 additions and 3 deletions.
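For context on the change below: JavaSparkContext.parallelize without an explicit slice count falls back to spark.default.parallelism, so a short list of compaction buckets or task groups can be scattered across many mostly empty partitions. The following is a minimal standalone sketch of that behavior, not part of this commit; the class name, the local[8] master, and the sample list are illustrative assumptions.

// Illustrative sketch only; not part of the patch.
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public class ParallelizeSlicesDemo {
    public static void main(String[] args) {
        SparkSession spark =
                SparkSession.builder().master("local[8]").appName("demo").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // Pretend these are the three serialized compaction tasks to distribute.
        List<Integer> tasks = Arrays.asList(1, 2, 3);

        // Without an explicit slice count, parallelize uses spark.default.parallelism
        // (8 with local[8]), so most of the resulting partitions are empty.
        System.out.println(jsc.parallelize(tasks).getNumPartitions()); // 8

        // Capping the slice count at the number of tasks avoids the empty partitions.
        int readParallelism = Math.min(tasks.size(), jsc.defaultParallelism());
        System.out.println(jsc.parallelize(tasks, readParallelism).getNumPartitions()); // 3

        spark.stop();
    }
}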
@@ -56,6 +56,7 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.PaimonUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
@@ -66,6 +67,8 @@
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;

@@ -97,6 +100,8 @@
*/
public class CompactProcedure extends BaseProcedure {

private static final Logger LOG = LoggerFactory.getLogger(CompactProcedure.class);

private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
ProcedureParameter.required("table", StringType),
@@ -182,7 +187,6 @@ public InternalRow[] call(InternalRow args) {
dynamicOptions.putAll(ParameterUtils.parseCommaSeparatedKeyValues(options));
}
table = table.copy(dynamicOptions);

InternalRow internalRow =
newInternalRow(
execute(
@@ -279,10 +283,11 @@ private void compactAwareBucketTable(
return;
}

int readParallelism = readParallelism(partitionBuckets, spark());
BatchWriteBuilder writeBuilder = table.newBatchWriteBuilder();
JavaRDD<byte[]> commitMessageJavaRDD =
javaSparkContext
.parallelize(partitionBuckets)
.parallelize(partitionBuckets, readParallelism)
.mapPartitions(
(FlatMapFunction<Iterator<Pair<byte[], Integer>>, byte[]>)
pairIterator -> {
@@ -355,6 +360,7 @@ private void compactUnAwareBucketTable(
.collect(Collectors.toList());
}
if (compactionTasks.isEmpty()) {
LOG.info("Compaction task is empty.");
return;
}

@@ -368,10 +374,11 @@
throw new RuntimeException("serialize compaction task failed");
}

int readParallelism = readParallelism(serializedTasks, spark());
String commitUser = createCommitUser(table.coreOptions().toConfiguration());
JavaRDD<byte[]> commitMessageJavaRDD =
javaSparkContext
.parallelize(serializedTasks)
.parallelize(serializedTasks, readParallelism)
.mapPartitions(
(FlatMapFunction<Iterator<byte[]>, byte[]>)
taskIterator -> {
@@ -485,6 +492,22 @@ private Map<BinaryRow, DataSplit[]> packForSort(List<DataSplit> dataSplits) {
list -> list.toArray(new DataSplit[0]))));
}

private int readParallelism(List<?> groupedTasks, SparkSession spark) {
int sparkParallelism =
Math.max(
spark.sparkContext().defaultParallelism(),
spark.sessionState().conf().numShufflePartitions());
int readParallelism = Math.min(groupedTasks.size(), sparkParallelism);
if (sparkParallelism > readParallelism) {
LOG.warn(
String.format(
"Spark default parallelism (%s) is greater than bucket or task parallelism (%s),"
+ "we use %s as the final read parallelism",
sparkParallelism, readParallelism, readParallelism));
}
return readParallelism;
}

@VisibleForTesting
static String toWhere(String partitions) {
List<Map<String, String>> maps = ParameterUtils.getPartitions(partitions.split(";"));
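The new readParallelism helper caps the slice count at the number of grouped buckets or tasks, using the larger of spark.default.parallelism and spark.sql.shuffle.partitions as the upper bound, so each RDD partition receives at least one unit of work. Below is a hedged standalone sketch of that computation in use, not part of the commit; the class name, the local[2] master, and the sample task list are illustrative assumptions.

// Illustrative sketch only; mirrors the readParallelism() logic outside the procedure.
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;

public class ReadParallelismSketch {
    public static void main(String[] args) {
        SparkSession spark =
                SparkSession.builder().master("local[2]").appName("sketch").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // Pretend each element is one bucket (or one compaction task group).
        List<String> groupedTasks = Arrays.asList("bucket-0", "bucket-1", "bucket-2");

        int sparkParallelism =
                Math.max(
                        spark.sparkContext().defaultParallelism(),
                        spark.sessionState().conf().numShufflePartitions());
        // Never create more partitions than there are task groups.
        int readParallelism = Math.min(groupedTasks.size(), sparkParallelism);

        JavaRDD<String> rdd = jsc.parallelize(groupedTasks, readParallelism);
        // Each partition now holds at least one task group.
        rdd.glom().collect().forEach(partition -> System.out.println(partition));

        spark.stop();
    }
}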
@@ -24,6 +24,7 @@ import org.apache.paimon.spark.PaimonSparkTestBase
import org.apache.paimon.table.FileStoreTable
import org.apache.paimon.table.source.DataSplit

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted, SparkListenerStageSubmitted}
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.StreamTest
@@ -648,6 +649,87 @@ abstract class CompactProcedureTestBase extends PaimonSparkTestBase with StreamTest
}
}

test("Paimon Procedure: test aware-bucket compaction read parallelism") {
spark.sql(s"""
|CREATE TABLE T (id INT, value STRING)
|TBLPROPERTIES ('primary-key'='id', 'bucket'='3', 'write-only'='true')
|""".stripMargin)

val table = loadTable("T")
for (i <- 1 to 10) {
sql(s"INSERT INTO T VALUES ($i, '$i')")
}
assertResult(10)(table.snapshotManager().snapshotCount())

val buckets = table.newSnapshotReader().bucketEntries().asScala.map(_.bucket()).distinct.size
assertResult(3)(buckets)

val taskBuffer = scala.collection.mutable.ListBuffer.empty[Int]
val listener = new SparkListener {
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
taskBuffer += stageSubmitted.stageInfo.numTasks
}
}

try {
spark.sparkContext.addSparkListener(listener)

// spark.default.parallelism cannot be changed within a Spark session
// sparkParallelism is 2 and the bucket count is 3, so min(3, 2) = 2 is the read parallelism
spark.conf.set("spark.sql.shuffle.partitions", 2)
spark.sql("CALL sys.compact(table => 'T')")

// sparkParallelism is 5 and the bucket count is 3, so min(3, 5) = 3 is the read parallelism
spark.conf.set("spark.sql.shuffle.partitions", 5)
spark.sql("CALL sys.compact(table => 'T')")

assertResult(Seq(2, 3))(taskBuffer)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}

test("Paimon Procedure: test unaware-bucket compaction read parallelism") {
spark.sql(s"""
|CREATE TABLE T (id INT, value STRING)
|TBLPROPERTIES ('bucket'='-1', 'write-only'='true')
|""".stripMargin)

val table = loadTable("T")
for (i <- 1 to 12) {
sql(s"INSERT INTO T VALUES ($i, '$i')")
}
assertResult(12)(table.snapshotManager().snapshotCount())

val buckets = table.newSnapshotReader().bucketEntries().asScala.map(_.bucket()).distinct.size
// only has bucket-0
assertResult(1)(buckets)

val taskBuffer = scala.collection.mutable.ListBuffer.empty[Int]
val listener = new SparkListener {
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
taskBuffer += stageSubmitted.stageInfo.numTasks
}
}

try {
spark.sparkContext.addSparkListener(listener)

// spark.default.parallelism cannot be changed within a Spark session
// sparkParallelism is 2 and there are 6 task groups, so min(6, 2) = 2 is the read parallelism
spark.conf.set("spark.sql.shuffle.partitions", 2)
spark.sql("CALL sys.compact(table => 'T', options => 'compaction.max.file-num=2')")

// sparkParallelism is 5 and there are 3 task groups, so min(3, 5) = 3 is the read parallelism
spark.conf.set("spark.sql.shuffle.partitions", 5)
spark.sql("CALL sys.compact(table => 'T', options => 'compaction.max.file-num=2')")

assertResult(Seq(2, 3))(taskBuffer)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}

def lastSnapshotCommand(table: FileStoreTable): CommitKind = {
table.snapshotManager().latestSnapshot().commitKind()
}
