support cross partition insert
Zouxxyy committed Oct 31, 2023
1 parent 1deec10 commit 82a803c
Showing 11 changed files with 242 additions and 31 deletions.
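
For context, this commit wires Paimon's cross-partition update path (BucketMode.GLOBAL_DYNAMIC) into the Spark batch writer. A hedged sketch of the workload it targets, written against the Spark SQL surface (spark is an existing SparkSession): a primary-key table whose key does not include the partition column and which uses bucket = -1, so inserting an existing key again may land it in a different partition. The table name and schema below are illustrative only, not taken from this commit.

// Hypothetical table: primary key (id) omits the partition column (dt) and bucket = -1,
// the layout that resolves to BucketMode.GLOBAL_DYNAMIC.
spark.sql(
  "CREATE TABLE t (id INT, v STRING, dt STRING) PARTITIONED BY (dt) " +
    "TBLPROPERTIES ('primary-key' = 'id', 'bucket' = '-1')")
spark.sql("INSERT INTO t VALUES (1, 'a', '2023-10-31')")
// Re-inserting the same key under another partition value is the cross-partition
// insert this commit adds support for.
spark.sql("INSERT INTO t VALUES (1, 'b', '2023-11-01')")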
@@ -16,7 +16,7 @@
* limitations under the License.
*/

package org.apache.paimon.flink.sink.index;
package org.apache.paimon.crosspartition;

/** Type of record, key or full row. */
public enum KeyPartOrRow {
@@ -30,7 +30,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partiton fields. */
/** A {@link PartitionKeyExtractor} to {@link InternalRow} with only key and partition fields. */
public class KeyPartPartitionKeyExtractor implements PartitionKeyExtractor<InternalRow> {

private final Projection partitionProjection;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.IndexBootstrap;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.sink.Committable;
import org.apache.paimon.flink.sink.DynamicBucketRowWriteOperator;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.GlobalIndexAssigner;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.disk.IOManager;
import org.apache.paimon.table.Table;
@@ -19,6 +19,7 @@
package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.IndexBootstrap;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.utils.SerializableFunction;

@@ -20,6 +20,7 @@

import org.apache.paimon.codegen.CodeGenUtils;
import org.apache.paimon.codegen.Projection;
import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.flink.sink.ChannelComputer;
@@ -18,6 +18,7 @@

package org.apache.paimon.flink.sink.index;

import org.apache.paimon.crosspartition.KeyPartOrRow;
import org.apache.paimon.flink.utils.InternalTypeSerializer;

import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -28,7 +29,7 @@
import java.io.IOException;
import java.util.Objects;

import static org.apache.paimon.flink.sink.index.KeyPartOrRow.KEY_PART;
import static org.apache.paimon.crosspartition.KeyPartOrRow.KEY_PART;

/** A {@link InternalTypeSerializer} to serialize KeyPartOrRow with T. */
public class KeyWithRowSerializer<T> extends InternalTypeSerializer<Tuple2<KeyPartOrRow, T>> {
@@ -24,6 +24,7 @@
import org.apache.paimon.data.InternalMap;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.data.Timestamp;
import org.apache.paimon.spark.schema.SparkSystemColumns;
import org.apache.paimon.types.ArrayType;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DateType;
@@ -35,6 +36,7 @@
import org.apache.paimon.shade.guava30.com.google.common.collect.Lists;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;

import java.io.Serializable;
import java.sql.Date;
@@ -43,6 +45,7 @@
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

@@ -54,11 +57,35 @@ public class SparkRow implements InternalRow, Serializable {

private final RowType type;
private final Row row;
private RowKind rowKind = RowKind.INSERT;
private final RowKind rowKind;

public SparkRow(RowType type, Row row) {
this.type = type;
this.row = row;
this.rowKind = getRowkind(row);
}

private SparkRow(RowType type, Row row, RowKind rowKind) {
this.type = type;
this.row = row;
this.rowKind = rowKind;
}

public static SparkRow reSerializeRow(
RowType type,
Row row,
ExpressionEncoder.Serializer<Row> toRow,
ExpressionEncoder.Deserializer<Row> fromRow) {
return new SparkRow(type, fromRow.apply(toRow.apply(row)), getRowkind(row));
}

private static RowKind getRowkind(Row row) {
if (Arrays.asList(row.schema().fieldNames()).contains(SparkSystemColumns.ROW_KIND_COL())) {
return RowKind.fromByteValue(
row.getByte(row.schema().fieldIndex(SparkSystemColumns.ROW_KIND_COL())));
} else {
return RowKind.INSERT;
}
}

@Override
@@ -73,7 +100,7 @@ public RowKind getRowKind() {

@Override
public void setRowKind(RowKind rowKind) {
this.rowKind = rowKind;
throw new UnsupportedOperationException();
}

@Override
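Note on the SparkRow change above: the row kind is now derived once, at construction time, from the row-kind system column (SparkSystemColumns.ROW_KIND_COL) when it is present, and setRowKind throws instead of mutating the row. A minimal Scala-side sketch of how a caller tags rows so getRowkind can recover the kind; it mirrors the lit(RowKind.INSERT.toByteValue) call added in WriteIntoPaimonTable below, and the DataFrame df here is hypothetical.

import org.apache.paimon.spark.schema.SparkSystemColumns.ROW_KIND_COL
import org.apache.paimon.types.RowKind
import org.apache.spark.sql.functions.lit

// Hypothetical DataFrame `df`: append the row-kind system column so that SparkRow
// reads RowKind.INSERT back out of the row instead of relying on setRowKind.
val withRowKind = df.withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))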
@@ -18,20 +18,25 @@
package org.apache.paimon.spark.commands

import org.apache.paimon.CoreOptions.DYNAMIC_PARTITION_OVERWRITE
import org.apache.paimon.data.BinaryRow
import org.apache.paimon.codegen.{CodeGenUtils, Projection}
import org.apache.paimon.crosspartition.{GlobalIndexAssigner, IndexBootstrap, KeyPartOrRow}
import org.apache.paimon.data.{BinaryRow, GenericRow, JoinedRow}
import org.apache.paimon.data.serializer.InternalSerializers
import org.apache.paimon.index.PartitionIndex
import org.apache.paimon.options.Options
import org.apache.paimon.spark.{DynamicOverWrite, InsertInto, Overwrite, SaveMode, SparkConnectorOptions, SparkRow}
import org.apache.paimon.spark._
import org.apache.paimon.spark.SparkUtils.createIOManager
import org.apache.paimon.spark.schema.SparkSystemColumns
import org.apache.paimon.spark.schema.SparkSystemColumns.{BUCKET_COL, ROW_KIND_COL}
import org.apache.paimon.spark.util.EncoderUtils
import org.apache.paimon.table.{BucketMode, FileStoreTable}
import org.apache.paimon.table.sink.{BatchWriteBuilder, CommitMessageSerializer, DynamicBucketRow, RowPartitionKeyExtractor}
import org.apache.paimon.types.{RowKind, RowType}
import org.apache.paimon.utils.SerializationUtils

import org.apache.spark.TaskContext
import org.apache.spark.{HashPartitioner, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -78,15 +83,21 @@ case class WriteIntoPaimonTable(
val primaryKeyCols = tableSchema.trimmedPrimaryKeys().asScala.map(col)
val partitionCols = tableSchema.partitionKeys().asScala.map(col)

val dataEncoder = EncoderUtils.encode(dataSchema).resolveAndBind()
val originFromRow = dataEncoder.createDeserializer()
val (_, _, originFromRow) = EncoderUtils.getEncoderAndSerDe(dataSchema)

// append _bucket_ column as placeholder
val withBucketCol = data.withColumn(BUCKET_COL, lit(-1))
val bucketColIdx = withBucketCol.schema.size - 1
val withBucketDataEncoder = EncoderUtils.encode(withBucketCol.schema).resolveAndBind()
val toRow = withBucketDataEncoder.createSerializer()
val fromRow = withBucketDataEncoder.createDeserializer()
var newData = data
if (
bucketMode.equals(BucketMode.GLOBAL_DYNAMIC) && !newData.schema.fieldNames.contains(
ROW_KIND_COL)
) {
newData = data.withColumn(ROW_KIND_COL, lit(RowKind.INSERT.toByteValue))
}

// append bucket column as placeholder
newData = newData.withColumn(BUCKET_COL, lit(-1))
val bucketColIdx = newData.schema.fieldNames.indexOf(BUCKET_COL)

val (newDataEncoder, toRow, fromRow) = EncoderUtils.getEncoderAndSerDe(newData.schema)

def repartitionByBucket(ds: Dataset[Row]) = {
ds.toDF().repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
@@ -95,32 +106,117 @@
val rowType = table.rowType()
val writeBuilder = table.newBatchWriteBuilder()

val df =
val df: Dataset[Row] =
bucketMode match {
case BucketMode.DYNAMIC =>
// Topology: input -- shuffle by key hash --> bucket-assigner -- shuffle by partition & bucket
val partitioned = if (primaryKeyCols.nonEmpty) {
// Make sure that the records with the same bucket values are within a task.
withBucketCol.repartition(primaryKeyCols: _*)
newData.repartition(primaryKeyCols: _*)
} else {
withBucketCol
newData
}
val numSparkPartitions = partitioned.rdd.getNumPartitions
val dynamicBucketProcessor =
DynamicBucketProcessor(table, rowType, bucketColIdx, numSparkPartitions, toRow, fromRow)
repartitionByBucket(
partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(
withBucketDataEncoder))
partitioned.mapPartitions(dynamicBucketProcessor.processPartition)(newDataEncoder))
case BucketMode.GLOBAL_DYNAMIC =>
// Topology: input -- bootstrap -- shuffle by key hash --> bucket-assigner -- shuffle by partition & bucket
val numSparkPartitions = newData.rdd.getNumPartitions
val rowType: RowType = SparkTypeUtils.toPaimonType(newData.schema).asInstanceOf[RowType]

// row: (keyHash, (kind, internalRow))
val bootstrapRow: RDD[(Int, (KeyPartOrRow, Array[Byte]))] = newData.rdd.mapPartitions {
iter =>
{
val sparkPartitionId = TaskContext.getPartitionId()

val bootstrapType: RowType = IndexBootstrap.bootstrapType(table.schema())
val primaryKeys: java.util.List[String] = table.schema().primaryKeys()
val rowProject: Projection = CodeGenUtils.newProjection(rowType, primaryKeys)
val keyPartProject: Projection =
CodeGenUtils.newProjection(bootstrapType, primaryKeys)

val lst = scala.collection.mutable.ListBuffer[(Int, (KeyPartOrRow, Array[Byte]))]()
val bootstrap = new IndexBootstrap(table)
bootstrap.bootstrap(
numSparkPartitions,
sparkPartitionId,
row => {
val bootstrapSer = InternalSerializers.create(bootstrapType)
val bytes: Array[Byte] =
SerializationUtils.serializeBinaryRow(bootstrapSer.toBinaryRow(row))
lst.append((keyPartProject(row).hashCode(), (KeyPartOrRow.KEY_PART, bytes)))
}
)
lst.iterator ++ iter.map(
r => {
val sparkRow = new SparkRow(rowType, r)
val rowSer = InternalSerializers.create(rowType)
val bytes: Array[Byte] =
SerializationUtils.serializeBinaryRow(rowSer.toBinaryRow(sparkRow))
(rowProject(sparkRow).hashCode(), (KeyPartOrRow.ROW, bytes))
})
}
}

var assignerParallelism: Integer = table.coreOptions.dynamicBucketAssignerParallelism
if (assignerParallelism == null) {
assignerParallelism = numSparkPartitions
}

val value: RDD[Row] =
bootstrapRow.partitionBy(new HashPartitioner(assignerParallelism)).mapPartitions {
iter =>
{
val sparkPartitionId = TaskContext.getPartitionId()
val lst = scala.collection.mutable.ListBuffer[Row]()
val ioManager = createIOManager
val assigner = new GlobalIndexAssigner(table)

assigner.open(
ioManager,
assignerParallelism,
sparkPartitionId,
(row, bucket) => {
val extraRow: GenericRow = new GenericRow(2)
extraRow.setField(0, row.getRowKind.toByteValue)
extraRow.setField(1, bucket)
lst.append(
fromRow(SparkInternalRow.fromPaimon(new JoinedRow(row, extraRow), rowType)))
}
)
while (iter.hasNext) {
val tuple: (KeyPartOrRow, Array[Byte]) = iter.next()._2
val binaryRow = SerializationUtils.deserializeBinaryRow(tuple._2)
tuple._1 match {
case KeyPartOrRow.KEY_PART => assigner.bootstrapKey(binaryRow)
case KeyPartOrRow.ROW => assigner.processInput(binaryRow)
case _ => throw new UnsupportedOperationException(s"unknown kind ${tuple._1}")
}
}
assigner.endBoostrap(true)
lst.iterator
}
}
sparkSession
.createDataFrame(value, newData.schema)
.repartition(partitionCols ++ Seq(col(BUCKET_COL)): _*)
case BucketMode.UNAWARE =>
// Topology: input -- bucket-assigner
val unawareBucketProcessor = UnawareBucketProcessor(bucketColIdx, toRow, fromRow)
withBucketCol
.mapPartitions(unawareBucketProcessor.processPartition)(withBucketDataEncoder)
newData
.mapPartitions(unawareBucketProcessor.processPartition)(newDataEncoder)
.toDF()
case BucketMode.FIXED =>
// Topology: input -- bucket-assigner -- shuffle by partition & bucket
val commonBucketProcessor =
CommonBucketProcessor(writeBuilder, bucketColIdx, toRow, fromRow)
repartitionByBucket(
withBucketCol.mapPartitions(commonBucketProcessor.processPartition)(
withBucketDataEncoder))
newData.mapPartitions(commonBucketProcessor.processPartition)(newDataEncoder))
case _ =>
throw new UnsupportedOperationException(s"unsupported bucket mode $bucketMode")
}

val commitMessages = df
@@ -133,12 +229,7 @@
iter.foreach {
row =>
val bucket = row.getInt(bucketColIdx)
val bucketColDropped = originFromRow(toRow(row))
val sparkRow = new SparkRow(rowType, bucketColDropped)
if (row.schema.fieldNames.contains(ROW_KIND_COL)) {
val rowKind = RowKind.fromByteValue(row.getAs(ROW_KIND_COL))
sparkRow.setRowKind(rowKind)
}
val sparkRow = SparkRow.reSerializeRow(rowType, row, toRow, originFromRow)
write.write(new DynamicBucketRow(sparkRow, bucket))
}
val serializer = new CommitMessageSerializer
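The GLOBAL_DYNAMIC branch above follows the topology named in its comment: bootstrap the existing key index, shuffle both the bootstrapped KEY_PART entries and the incoming ROW records by the hash of their primary-key projection, run a GlobalIndexAssigner per partition to assign buckets, then repartition by partition and bucket for the actual write. Below is a condensed, illustrative sketch of the shuffle contract only; the helper name is invented here and is not part of the commit.

import scala.reflect.ClassTag

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.RDD

// Both KEY_PART (bootstrapped index entries) and ROW (incoming data) records carry the
// hash of their primary-key projection as the RDD key, so a plain HashPartitioner puts
// every record of a given primary key on the same GlobalIndexAssigner task.
def coLocateByKeyHash[T: ClassTag](keyed: RDD[(Int, T)], assignerParallelism: Int): RDD[(Int, T)] =
  keyed.partitionBy(new HashPartitioner(assignerParallelism))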
@@ -63,4 +63,10 @@ object EncoderUtils {
.reflectMethod(method)(schema)
.asInstanceOf[ExpressionEncoder[Row]]
}

def getEncoderAndSerDe(schema: StructType)
: (ExpressionEncoder[Row], ExpressionEncoder.Serializer[Row], ExpressionEncoder.Deserializer[Row]) = {
val encoder = encode(schema).resolveAndBind()
(encoder, encoder.createSerializer(), encoder.createDeserializer())
}
}
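
The new getEncoderAndSerDe helper above bundles resolveAndBind with serializer and deserializer creation; the write path uses its output for the encoder round trip that strips the helper columns (via SparkRow.reSerializeRow) before writing. A small usage sketch under an assumed two-column schema; the schema itself is hypothetical.

import org.apache.paimon.spark.util.EncoderUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Hypothetical schema; the helper returns the resolved encoder together with its
// serializer (Row => InternalRow) and deserializer (InternalRow => Row).
val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
val (encoder, toRow, fromRow) = EncoderUtils.getEncoderAndSerDe(schema)
val roundTripped: Row = fromRow(toRow(Row(1, "a")))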