[spark] Support cross partition write #2212

Closed
@@ -16,7 +16,7 @@
* limitations under the License.
*/

package org.apache.paimon.flink.sink.index;
package org.apache.paimon.crosspartition;

/** Type of record, key or full row. */
public enum KeyPartOrRow {
@@ -30,7 +30,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partiton fields. */
/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partition fields. */
public class KeyPartPartitionKeyExtractor implements PartitionKeyExtractor<InternalRow> {

private final Projection partitionProjection;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.IndexBootstrap;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.sink.Committable;
import org.apache.paimon.flink.sink.DynamicBucketRowWriteOperator;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.GlobalIndexAssigner;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.disk.IOManager;
import org.apache.paimon.table.Table;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.IndexBootstrap;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.utils.SerializableFunction;

@@ -20,6 +20,7 @@

import org.apache.paimon.codegen.CodeGenUtils;
import org.apache.paimon.codegen.Projection;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.sink.ChannelComputer;
@@ -18,6 +18,7 @@

package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.flink.utils.InternalTypeSerializer;

import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -28,7 +29,7 @@
import java.io.IOException;
import java.util.Objects;

import static org.apache.paimon.flink.sink.index.KeyPartOrRow.KEY_PART;
import static org.apache.paimon.crosspartition.KeyPartOrRow.KEY_PART;

/** A {@link InternalTypeSerializer} to serialize KeyPartOrRow with T. */
public class KeyWithRowSerializer<T> extends InternalTypeSerializer<Tuple2<KeyPartOrRow, T>> {
@@ -18,20 +18,25 @@
package org.apache.paimon.spark.commands

import org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE
import org.apache.paimon.data.BinaryRow
import org.apache.paimon.codegen.{CodeGenUtils, Projection}
import org.apache.paimon.crosspartition.{GlobalIndexAssigner, IndexBootstrap, KeyPartOrRow}
import org.apache.paimon.data.{BinaryRow, GenericRow, JoinedRow}
import org.apache.paimon.data.serializer.InternalSerializers
import org.apache.paimon.index.PartitionIndex
import org.apache.paimon.options.Options
import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, SaveMode, SparkConnectorOptions, SparkRow}
import org.apache.paimon.spark._
import org.apache.paimon.spark.SparkUtils.createIOManager
import org.apache.paimon.spark.schema.SparkSystemColumns
import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, ROW_KIND_COL}
import org.apache.paimon.spark.util.{EncoderUtils, SparkRowUtils}
import org.apache.paimon.table.{BucketMode, FileStoreTable}
import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
import org.apache.paimon.types.RowType
import org.apache.paimon.types.{RowKind, RowType}
import org.apache.paimon.utils.SerializationUtils

import org.apache.spark.TaskContext
import org.apache.spark.{HashPartitioner, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -78,51 +83,117 @@ case class WriteIntoPaimonTable(
val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
val partitionCols = tableSchema.partitionKeys().asScala.map(col)

val dataEncoder = EncoderUtils.encode(dataSchema).resolveAndBind()
val originFromRow = dataEncoder.createDeserializer()
val (_, _, originFromRow) = EncoderUtils.getEncoderAndSerDe(dataSchema)

val rowkindColIdx = SparkRowUtils.getFieldIndex(data.schema, ROW_KIND_COL)
var newData = data

// append _bucket_ column as placeholder
val withBucketCol = data.withColumn(BUCKET_COL, lit(-1))
val bucketColIdx = withBucketCol.schema.size - 1
val withBucketDataEncoder = EncoderUtils.encode(withBucketCol.schema).resolveAndBind()
val toRow = withBucketDataEncoder.createSerializer()
val fromRow = withBucketDataEncoder.createDeserializer()
if (
bucketMode.equals(BucketMode.GLOBAL_DYNAMIC) && !newData.schema.fieldNames.contains(
ROW_KIND_COL)
) {
newData = data.withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
}
val rowkindColIdx = SparkRowUtils.getFieldIndex(newData.schema, ROW_KIND_COL)

// append bucket column as placeholder
newData = newData.withColumn(BUCKET_COL, lit(-1))
val bucketColIdx = SparkRowUtils.getFieldIndex(newData.schema, BUCKET_COL)

val (newDataEncoder, toRow, fromRow) = EncoderUtils.getEncoderAndSerDe(newData.schema)

def repartitionByBucket(ds: Dataset[Row]) = {
ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
def repartitionByBucket(df: DataFrame) = {
df.repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
}

val rowType = table.rowType()
val writeBuilder = table.newBatchWriteBuilder()

val df =
val df: Dataset[Row] =
bucketMode match {
case BucketMode.DYNAMIC =>
// Topology: input -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
val partitioned = if (primaryKeyCols.nonEmpty) {
// Make sure that the records with the same bucket values is within a task.
withBucketCol.repartition(primaryKeyCols: _*)
newData.repartition(primaryKeyCols: _*)
} else {
withBucketCol
newData
}
val numSparkPartitions = partitioned.rdd.getNumPartitions
val dynamicBucketProcessor =
DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)

repartitionByBucket(
partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
withBucketDataEncoder))
partitioned
.mapPartitions(dynamicBucketProcessor.processPartition)(newDataEncoder)
.toDF())
case BucketMode.GLOBAL_DYNAMIC =>
// Topology: input -> bootstrap -> shuffle by key hash -> bucket-assigner -> shuffle by partition & bucket
val numSparkPartitions = newData.rdd.getNumPartitions
val primaryKeys: java.util.List[String] = table.schema().primaryKeys()
val bootstrapType: RowType = IndexBootstrap.bootstrapType(table.schema())
val rowType: RowType = SparkTypeUtils.toPaimonType(newData.schema).asInstanceOf[RowType]

// row: (keyHash, (kind, internalRow))
val bootstrapRow: RDD[(Int, (KeyPartOrRow, Array[Byte]))] = newData.rdd.mapPartitions {
iter =>
{
val sparkPartitionId = TaskContext.getPartitionId()

val keyPartProject: Projection =
CodeGenUtils.newProjection(bootstrapType, primaryKeys)
val rowProject: Projection = CodeGenUtils.newProjection(rowType, primaryKeys)
val bootstrapSer = InternalSerializers.create(bootstrapType)
val rowSer = InternalSerializers.create(rowType)

val lst = scala.collection.mutable.ListBuffer[(Int, (KeyPartOrRow, Array[Byte]))]()

val bootstrap = new IndexBootstrap(table)
bootstrap.bootstrap(
numSparkPartitions,
sparkPartitionId,
row => {
val bytes: Array[Byte] =
SerializationUtils.serializeBinaryRow(bootstrapSer.toBinaryRow(row))
lst.append((keyPartProject(row).hashCode(), (KeyPartOrRow.KEY_PART, bytes)))
}
)
lst.iterator ++ iter.map(
r => {
val sparkRow =
new SparkRow(rowType, r, SparkRowUtils.getRowKind(r, rowkindColIdx))
val bytes: Array[Byte] =
SerializationUtils.serializeBinaryRow(rowSer.toBinaryRow(sparkRow))
(rowProject(sparkRow).hashCode(), (KeyPartOrRow.ROW, bytes))
})
}
}

var assignerParallelism: Integer = table.coreOptions.dynamicBucketAssignerParallelism
if (assignerParallelism == null) {
assignerParallelism = numSparkPartitions
}

val globalDynamicBucketProcessor =
GlobalDynamicBucketProcessor(table, rowType, fromRow, assignerParallelism)
val rowRDD = bootstrapRow
.partitionBy(new HashPartitioner(assignerParallelism))
.mapPartitions(globalDynamicBucketProcessor.processPartition)

repartitionByBucket(sparkSession.createDataFrame(rowRDD, newData.schema))
case BucketMode.UNAWARE =>
// Topology: input -> bucket-assigner
val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
withBucketCol
.mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
newData
.mapPartitions(unawareBucketProcessor.processPartition)(newDataEncoder)
.toDF()
case BucketMode.FIXED =>
// Topology: input -> bucket-assigner -> shuffle by partition & bucket
val commonBucketProcessor =
CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
repartitionByBucket(
withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
withBucketDataEncoder))
newData.mapPartitions(commonBucketProcessor.processPartition)(newDataEncoder).toDF())
case _ =>
throw new UnsupportedOperationException(s"unsupported bucket mode $bucketMode")
}

val commitMessages = df
@@ -196,16 +267,16 @@

object WriteIntoPaimonTable {

sealed trait BucketProcessor {
def processPartition(rowIterator: Iterator[Row]): Iterator[Row]
sealed trait BucketProcessor[In] {
def processPartition(rowIterator: Iterator[In]): Iterator[Row]
}

case class CommonBucketProcessor(
writeBuilder: BatchWriteBuilder,
bucketColIndex: Int,
toRow: ExpressionEncoder.Serializer[Row],
fromRow: ExpressionEncoder.Deserializer[Row])
extends BucketProcessor {
extends BucketProcessor[Row] {

private val rowType = writeBuilder.rowType

@@ -233,7 +304,7 @@ object WriteIntoPaimonTable {
numSparkPartitions: Long,
toRow: ExpressionEncoder.Serializer[Row],
fromRow: ExpressionEncoder.Deserializer[Row]
) extends BucketProcessor {
) extends BucketProcessor[Row] {

private val targetBucketRowNumber = fileStoreTable.coreOptions.dynamicBucketTargetRowNum

@@ -270,11 +341,58 @@
}
}

case class GlobalDynamicBucketProcessor(
fileStoreTable: FileStoreTable,
rowType: RowType,
fromRow: ExpressionEncoder.Deserializer[Row],
assignerParallelism: Integer)
extends BucketProcessor[(Int, (KeyPartOrRow, Array[Byte]))] {

override def processPartition(
iter: Iterator[(Int, (KeyPartOrRow, Array[Byte]))]): Iterator[Row] = {
val sparkPartitionId = TaskContext.getPartitionId()
val lst = scala.collection.mutable.ListBuffer[Row]()
val ioManager = createIOManager
val assigner = new GlobalIndexAssigner(fileStoreTable)
try {
assigner.open(
ioManager,
assignerParallelism,
sparkPartitionId,
(row, bucket) => {
val extraRow: GenericRow = new GenericRow(2)
extraRow.setField(0, row.getRowKind.toByteValue)
extraRow.setField(1, bucket)
lst.append(fromRow(SparkInternalRow.fromPaimon(new JoinedRow(row, extraRow), rowType)))
}
)
iter.foreach(
row => {
val tuple: (KeyPartOrRow, Array[Byte]) = row._2
val binaryRow = SerializationUtils.deserializeBinaryRow(tuple._2)
tuple._1 match {
case KeyPartOrRow.KEY_PART => assigner.bootstrapKey(binaryRow)
case KeyPartOrRow.ROW => assigner.processInput(binaryRow)
case _ =>
throw new UnsupportedOperationException(s"unknown kind ${tuple._1}")
}
})
assigner.endBoostrap(true)
lst.iterator
} finally {
assigner.close()
if (ioManager != null) {
ioManager.close()
}
}
}
}
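
As an aside on the GLOBAL_DYNAMIC path above (this sketch is not part of the diff): the topology tags bootstrapped index entries as KEY_PART and incoming rows as ROW, keys both by primary-key hash, and co-partitions them with a HashPartitioner so each task sees every record for a given key and can replay the KEY_PART records before assigning buckets to the ROW records. The object names, the Tag enumeration, and the string payloads below are hypothetical placeholders; the real code serializes BinaryRows and drives IndexBootstrap and GlobalIndexAssigner instead of printing.

```scala
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

// Stand-in for KeyPartOrRow: KEY_PART entries bootstrapped from existing
// partitions vs. ROW records from the current write (hypothetical names).
object Tag extends Enumeration { val KeyPart, Row = Value }

object CrossPartitionShuffleSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val sc = spark.sparkContext

    // Toy inputs keyed by primary key; in the PR the payloads are serialized
    // BinaryRows produced by IndexBootstrap and SparkRow respectively.
    val bootstrapped = sc.parallelize(Seq("k1" -> "pt=2022/bucket=0", "k2" -> "pt=2022/bucket=1"))
    val incoming     = sc.parallelize(Seq("k1" -> "new row for k1", "k3" -> "new row for k3"))

    // Tag both streams and key them by primary-key hash, like the bootstrapRow RDD.
    val tagged =
      bootstrapped.map { case (k, v) => (k.hashCode, (Tag.KeyPart, v)) } ++
        incoming.map { case (k, v) => (k.hashCode, (Tag.Row, v)) }

    val assigned = tagged
      .partitionBy(new HashPartitioner(4)) // 4 stands in for assignerParallelism
      .mapPartitions { iter =>
        // Replay KEY_PART records first (bootstrapKey), then handle ROW
        // records (processInput), mirroring GlobalDynamicBucketProcessor.
        val (keyParts, rows) = iter.toList.partition { case (_, (tag, _)) => tag == Tag.KeyPart }
        keyParts.foreach { case (_, (_, v)) => println(s"bootstrapKey: $v") }
        rows.iterator.map { case (hash, (_, v)) => s"assign bucket for '$v' (key hash $hash)" }
      }

    assigned.collect().foreach(println)
    spark.stop()
  }
}
```

Replaying the bootstrapped index before the new rows is what lets the assigner notice that an incoming key already lives in another partition and route the cross-partition update accordingly.
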

case class UnawareBucketProcessor(
bucketColIndex: Int,
toRow: ExpressionEncoder.Serializer[Row],
fromRow: ExpressionEncoder.Deserializer[Row])
extends BucketProcessor {
extends BucketProcessor[Row] {

def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = {
new Iterator[Row] {
@@ -63,4 +63,10 @@ object EncoderUtils {
.reflectMethod(method)(schema)
.asInstanceOf[ExpressionEncoder[Row]]
}

def getEncoderAndSerDe(schema: StructType)
: (ExpressionEncoder[Row], ExpressionEncoder.Serializer[Row], ExpressionEncoder.Deserializer[Row]) = {
val encoder = encode(schema).resolveAndBind()
(encoder, encoder.createSerializer(), encoder.createDeserializer())
}
}
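
A minimal usage sketch for the new getEncoderAndSerDe helper, assuming a hypothetical two-column schema and a session with paimon-spark on the classpath: the returned serializer maps external Rows to Catalyst InternalRows and the deserializer maps them back, which is how the toRow/fromRow pair is used by the bucket processors in WriteIntoPaimonTable.

```scala
import org.apache.paimon.spark.util.EncoderUtils

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical schema, only to show the round trip.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))

val (_, toRow, fromRow) = EncoderUtils.getEncoderAndSerDe(schema)

val internalRow = toRow(Row(1, "a"))   // external Row -> Catalyst InternalRow
val roundTrip   = fromRow(internalRow) // InternalRow -> external Row
```
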