[core] Revisit merge schema to fix issues (#4094)
ulysses-you committed Sep 10, 2024
1 parent 685aa4b commit 28e02a9
Showing 3 changed files with 164 additions and 106 deletions.
@@ -39,18 +39,21 @@
public class SchemaMergingUtils {

public static TableSchema mergeSchemas(
TableSchema currentTableSchema, RowType dataFields, boolean allowExplicitCast) {
if (currentTableSchema.logicalRowType().equals(dataFields)) {
TableSchema currentTableSchema, RowType targetType, boolean allowExplicitCast) {
RowType currentType = currentTableSchema.logicalRowType();
if (currentType.equals(targetType)) {
return currentTableSchema;
}

AtomicInteger highestFieldId = new AtomicInteger(currentTableSchema.highestFieldId());
RowType newRowType =
mergeSchemas(
currentTableSchema.logicalRowType(),
dataFields,
highestFieldId,
allowExplicitCast);
mergeSchemas(currentType, targetType, highestFieldId, allowExplicitCast);
if (newRowType == currentType) {
// It happens if the `targetType` only changes `nullability` but we always respect the
// current's.
return currentTableSchema;
}

return new TableSchema(
currentTableSchema.id() + 1,
newRowType.getFields(),
@@ -86,7 +89,7 @@ public static DataType merge(
DataType update0,
AtomicInteger highestFieldId,
boolean allowExplicitCast) {
// Here we try t0 merge the base0 and update0 without regard to the nullability,
// Here we try to merge the base0 and update0 without regard to the nullability,
// and set the base0's nullability to the return's.
DataType base = base0.copy(true);
DataType update = update0.copy(true);
@@ -134,10 +137,10 @@ public static DataType merge(
.collect(Collectors.toList());

updatedFields.addAll(newFields);
return new RowType(base.isNullable(), updatedFields);
return new RowType(base0.isNullable(), updatedFields);
} else if (base instanceof MapType && update instanceof MapType) {
return new MapType(
base.isNullable(),
base0.isNullable(),
merge(
((MapType) base).getKeyType(),
((MapType) update).getKeyType(),
@@ -150,27 +153,32 @@
allowExplicitCast));
} else if (base instanceof ArrayType && update instanceof ArrayType) {
return new ArrayType(
base.isNullable(),
base0.isNullable(),
merge(
((ArrayType) base).getElementType(),
((ArrayType) update).getElementType(),
highestFieldId,
allowExplicitCast));
} else if (base instanceof MultisetType && update instanceof MultisetType) {
return new MultisetType(
base.isNullable(),
base0.isNullable(),
merge(
((MultisetType) base).getElementType(),
((MultisetType) update).getElementType(),
highestFieldId,
allowExplicitCast));
} else if (base instanceof DecimalType && update instanceof DecimalType) {
if (base.equals(update)) {
return base0;
if (((DecimalType) base).getScale() == ((DecimalType) update).getScale()) {
return new DecimalType(
base0.isNullable(),
Math.max(
((DecimalType) base).getPrecision(),
((DecimalType) update).getPrecision()),
((DecimalType) base).getScale());
} else {
throw new UnsupportedOperationException(
String.format(
"Failed to merge decimal types with different precision or scale: %s and %s",
"Failed to merge decimal types with different scale: %s and %s",
base, update));
}
} else if (supportsDataTypesCast(base, update, allowExplicitCast)) {
@@ -212,10 +220,11 @@ public static DataType merge(

private static boolean supportsDataTypesCast(
DataType sourceType, DataType targetType, boolean allowExplicitCast) {
boolean canImplicitCast = DataTypeCasts.supportsImplicitCast(sourceType, targetType);
boolean canExplicitCast =
allowExplicitCast && DataTypeCasts.supportsExplicitCast(sourceType, targetType);
return canImplicitCast || canExplicitCast;
if (allowExplicitCast) {
return DataTypeCasts.supportsExplicitCast(sourceType, targetType);
} else {
return DataTypeCasts.supportsImplicitCast(sourceType, targetType);
}
}

private static DataField assignIdForNewField(DataField field, AtomicInteger highestFieldId) {
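The change above boils down to two behavioral rules: a merged type always keeps the base (current) type's nullability, and two decimals merge only when their scales match, taking the larger precision. A minimal sketch of both rules, using the same SchemaMergingUtils.merge(base, update, highestFieldId, allowExplicitCast) entry point that the tests below exercise. The concrete values here are illustrative, not part of the commit, and the Paimon type/utility imports are assumed to be the usual ones from the codebase.

import java.util.concurrent.atomic.AtomicInteger;
// org.apache.paimon type classes and SchemaMergingUtils imports assumed as in the codebase.

AtomicInteger highestFieldId = new AtomicInteger(10);

// Rule 1: the result keeps the base type's nullability, even though the element
// type itself is widened (INT -> BIGINT is an implicit, lossless cast).
DataType base = new ArrayType(false, new IntType());     // ARRAY<INT> NOT NULL
DataType update = new ArrayType(true, new BigIntType()); // ARRAY<BIGINT>
ArrayType merged = (ArrayType) SchemaMergingUtils.merge(base, update, highestFieldId, false);
// merged is ARRAY<BIGINT> NOT NULL: element widened, base nullability preserved.

// Rule 2: decimals with the same scale merge to the larger precision;
// different scales throw UnsupportedOperationException.
DataType d1 = new DecimalType(false, 5, 2); // DECIMAL(5, 2) NOT NULL
DataType d2 = new DecimalType(7, 2);        // DECIMAL(7, 2)
DecimalType d = (DecimalType) SchemaMergingUtils.merge(d1, d2, highestFieldId, false);
// d is DECIMAL(7, 2) NOT NULL.

At the table level, mergeSchemas now also returns the current TableSchema untouched when the merged row type differs from the current one only in nullability, so no new schema version is created for nullability-only updates. The updated unit tests follow.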
@@ -183,6 +183,7 @@ public void testMergeArrayTypes() {
// the element types aren't same, but can be evolved safety.
DataType t2 = new ArrayType(true, new BigIntType());
ArrayType r2 = (ArrayType) SchemaMergingUtils.merge(source, t2, highestFieldId, false);
assertThat(r2.isNullable()).isFalse();
assertThat(r2.getElementType() instanceof BigIntType).isTrue();

// the element types aren't same, and can't be evolved safety.
@@ -192,6 +193,7 @@ public void testMergeArrayTypes() {
// the value type of target's isn't same to the source's, but the source type can be cast to
// the target type explicitly.
ArrayType r3 = (ArrayType) SchemaMergingUtils.merge(source, t3, highestFieldId, true);
assertThat(r3.isNullable()).isFalse();
assertThat(r3.getElementType() instanceof SmallIntType).isTrue();
}

@@ -211,6 +213,7 @@ public void testMergeMapTypes() {
// the value type of target's isn't same to the source's, but can be evolved safety.
DataType t2 = new MapType(new VarCharType(VarCharType.MAX_LENGTH), new DoubleType());
MapType r2 = (MapType) SchemaMergingUtils.merge(source, t2, highestFieldId, false);
assertThat(r2.isNullable()).isTrue();
assertThat(r2.getKeyType() instanceof VarCharType).isTrue();
assertThat(r2.getValueType() instanceof DoubleType).isTrue();

@@ -221,6 +224,7 @@ public void testMergeMapTypes() {
// the value type of target's isn't same to the source's, but the source type can be cast to
// the target type explicitly.
MapType r3 = (MapType) SchemaMergingUtils.merge(source, t3, highestFieldId, true);
assertThat(r3.isNullable()).isTrue();
assertThat(r3.getKeyType() instanceof VarCharType).isTrue();
assertThat(r3.getValueType() instanceof SmallIntType).isTrue();
}
@@ -242,6 +246,7 @@ public void testMergeMultisetTypes() {
DataType t2 = new MultisetType(true, new BigIntType());
MultisetType r2 =
(MultisetType) SchemaMergingUtils.merge(source, t2, highestFieldId, false);
assertThat(r2.isNullable()).isFalse();
assertThat(r2.getElementType() instanceof BigIntType).isTrue();

// the element types aren't same, and can't be evolved safety.
@@ -251,6 +256,7 @@ public void testMergeMultisetTypes() {
// the value type of target's isn't same to the source's, but the source type can be cast to
// the target type explicitly.
MultisetType r3 = (MultisetType) SchemaMergingUtils.merge(source, t3, highestFieldId, true);
assertThat(r3.isNullable()).isFalse();
assertThat(r3.getElementType() instanceof SmallIntType).isTrue();
}

@@ -266,10 +272,24 @@ public void testMergeDecimalTypes() {
assertThat(r1.getScale()).isEqualTo(DecimalType.DEFAULT_SCALE);

DataType s2 = new DecimalType(5, 2);
DataType t2 = new DecimalType(7, 2);
DataType t2 = new DecimalType(7, 3);
assertThatThrownBy(() -> SchemaMergingUtils.merge(s2, t2, highestFieldId, false))
.isInstanceOf(UnsupportedOperationException.class);

DataType s3 = new DecimalType(false, 5, 2);
DataType t3 = new DecimalType(7, 2);
DecimalType r3 = (DecimalType) SchemaMergingUtils.merge(s3, t3, highestFieldId, false);
assertThat(r3.isNullable()).isFalse();
assertThat(r3.getPrecision()).isEqualTo(7);
assertThat(r3.getScale()).isEqualTo(2);

DataType s4 = new DecimalType(7, 2);
DataType t4 = new DecimalType(5, 2);
DecimalType r4 = (DecimalType) SchemaMergingUtils.merge(s4, t4, highestFieldId, false);
assertThat(r4.isNullable()).isTrue();
assertThat(r4.getPrecision()).isEqualTo(7);
assertThat(r4.getScale()).isEqualTo(2);

// DecimalType -> Other Numeric Type
DataType dcmSource = new DecimalType();
DataType iTarget = new IntType();
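The allowExplicitCast flag now selects exactly one cast check: explicit-cast rules when true, implicit-cast rules when false (previously true accepted either kind). A hypothetical call mirroring the t3 cases above; the test's actual source values are defined outside the shown hunks, so the element types below are assumptions.

AtomicInteger highestFieldId = new AtomicInteger(10);

// Assumed source: ARRAY<INT> NOT NULL (the real source in the test is not shown here).
DataType source = new ArrayType(false, new IntType());
DataType t3 = new ArrayType(true, new SmallIntType());

// INT -> SMALLINT is narrowing, so an implicit merge is rejected ...
// SchemaMergingUtils.merge(source, t3, highestFieldId, false); // throws UnsupportedOperationException
// ... while an explicit merge succeeds and, as above, keeps the source's nullability.
ArrayType r3 = (ArrayType) SchemaMergingUtils.merge(source, t3, highestFieldId, true);
// r3 is ARRAY<SMALLINT> NOT NULL.

The Spark integration test changes follow.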
@@ -21,6 +21,7 @@ package org.apache.paimon.spark.sql
import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.DecimalType
import org.junit.jupiter.api.Assertions

import java.sql.{Date, Timestamp}
@@ -177,92 +178,120 @@ class DataFrameWriteTest extends PaimonSparkTestBase {
}
}

withPk.foreach {
hasPk =>
bucketModes.foreach {
bucket =>
test(s"Schema evolution: write data into Paimon: $hasPk, bucket: $bucket") {
val _spark = spark
import _spark.implicits._

val prop = if (hasPk) {
s"'primary-key'='a', 'bucket' = '$bucket' "
} else if (bucket != -1) {
s"'bucket-key'='a', 'bucket' = '$bucket' "
} else {
"'write-only'='true'"
}

spark.sql(s"""
|CREATE TABLE T (a INT, b STRING)
|TBLPROPERTIES ($prop)
|""".stripMargin)

val paimonTable = loadTable("T")
val location = paimonTable.location().toString

val df1 = Seq((1, "a"), (2, "b")).toDF("a", "b")
df1.write.format("paimon").mode("append").save(location)
checkAnswer(
spark.sql("SELECT * FROM T ORDER BY a, b"),
Row(1, "a") :: Row(2, "b") :: Nil)

// Case 1: two additional fields
val df2 = Seq((1, "a2", 123L, Map("k" -> 11.1)), (3, "c", 345L, Map("k" -> 33.3)))
.toDF("a", "b", "c", "d")
df2.write
.format("paimon")
.mode("append")
.option("write.merge-schema", "true")
.save(location)
val expected2 = if (hasPk) {
Row(1, "a2", 123L, Map("k" -> 11.1)) ::
Row(2, "b", null, null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil
} else {
Row(1, "a", null, null) :: Row(1, "a2", 123L, Map("k" -> 11.1)) :: Row(
2,
"b",
null,
null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil
}
checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected2)

// Case 2: two fields with the evolved types: Int -> Long, Long -> Decimal
val df3 = Seq(
(2L, "b2", BigDecimal.decimal(234), Map("k" -> 22.2)),
(4L, "d", BigDecimal.decimal(456), Map("k" -> 44.4))).toDF("a", "b", "c", "d")
df3.write
.format("paimon")
.mode("append")
.option("write.merge-schema", "true")
.save(location)
val expected3 = if (hasPk) {
Row(1L, "a2", BigDecimal.decimal(123), Map("k" -> 11.1)) :: Row(
2L,
"b2",
BigDecimal.decimal(234),
Map("k" -> 22.2)) :: Row(3L, "c", BigDecimal.decimal(345), Map("k" -> 33.3)) :: Row(
4L,
"d",
BigDecimal.decimal(456),
Map("k" -> 44.4)) :: Nil
} else {
Row(1L, "a", null, null) :: Row(
1L,
"a2",
BigDecimal.decimal(123),
Map("k" -> 11.1)) :: Row(2L, "b", null, null) :: Row(
2L,
"b2",
BigDecimal.decimal(234),
Map("k" -> 22.2)) :: Row(3L, "c", BigDecimal.decimal(345), Map("k" -> 33.3)) :: Row(
4L,
"d",
BigDecimal.decimal(456),
Map("k" -> 44.4)) :: Nil
}
checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected3)

fileFormats.foreach {
format =>
withPk.foreach {
hasPk =>
bucketModes.foreach {
bucket =>
test(
s"Schema evolution: write data into Paimon: $hasPk, bucket: $bucket, format: $format") {
val _spark = spark
import _spark.implicits._

val prop = if (hasPk) {
s"'primary-key'='a', 'bucket' = '$bucket', 'file.format' = '$format'"
} else if (bucket != -1) {
s"'bucket-key'='a', 'bucket' = '$bucket', 'file.format' = '$format'"
} else {
s"'write-only'='true', 'file.format' = '$format'"
}

spark.sql(s"""
|CREATE TABLE T (a INT, b STRING)
|TBLPROPERTIES ($prop)
|""".stripMargin)

val paimonTable = loadTable("T")
val location = paimonTable.location().toString

val df1 = Seq((1, "a"), (2, "b")).toDF("a", "b")
df1.write.format("paimon").mode("append").save(location)
checkAnswer(
spark.sql("SELECT * FROM T ORDER BY a, b"),
Row(1, "a") :: Row(2, "b") :: Nil)

// Case 1: two additional fields
val df2 = Seq((1, "a2", 123L, Map("k" -> 11.1)), (3, "c", 345L, Map("k" -> 33.3)))
.toDF("a", "b", "c", "d")
df2.write
.format("paimon")
.mode("append")
.option("write.merge-schema", "true")
.save(location)
val expected2 = if (hasPk) {
Row(1, "a2", 123L, Map("k" -> 11.1)) ::
Row(2, "b", null, null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil
} else {
Row(1, "a", null, null) :: Row(1, "a2", 123L, Map("k" -> 11.1)) :: Row(
2,
"b",
null,
null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil
}
checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected2)

// Case 2: two fields with the evolved types: Int -> Long, Long -> Decimal
val df3 = Seq(
(2L, "b2", BigDecimal.decimal(234), Map("k" -> 22.2)),
(4L, "d", BigDecimal.decimal(456), Map("k" -> 44.4))).toDF("a", "b", "c", "d")
df3.write
.format("paimon")
.mode("append")
.option("write.merge-schema", "true")
.save(location)
val expected3 = if (hasPk) {
Row(1L, "a2", BigDecimal.decimal(123), Map("k" -> 11.1)) :: Row(
2L,
"b2",
BigDecimal.decimal(234),
Map("k" -> 22.2)) :: Row(
3L,
"c",
BigDecimal.decimal(345),
Map("k" -> 33.3)) :: Row(
4L,
"d",
BigDecimal.decimal(456),
Map("k" -> 44.4)) :: Nil
} else {
Row(1L, "a", null, null) :: Row(
1L,
"a2",
BigDecimal.decimal(123),
Map("k" -> 11.1)) :: Row(2L, "b", null, null) :: Row(
2L,
"b2",
BigDecimal.decimal(234),
Map("k" -> 22.2)) :: Row(
3L,
"c",
BigDecimal.decimal(345),
Map("k" -> 33.3)) :: Row(
4L,
"d",
BigDecimal.decimal(456),
Map("k" -> 44.4)) :: Nil
}
checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected3)

// Case 3: insert Decimal(20,18) to Decimal(38,18)
val df4 = Seq((99L, "df4", BigDecimal.decimal(4.0), Map("4" -> 4.1)))
.toDF("a", "b", "c", "d")
.selectExpr("a", "b", "cast(c as decimal(20,18)) as c", "d")
df4.write
.format("paimon")
.mode("append")
.option("write.merge-schema", "true")
.save(location)
val expected4 =
expected3 ++ Seq(Row(99L, "df4", BigDecimal.decimal(4.0), Map("4" -> 4.1)))
checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected4)
val decimalType =
spark.table("T").schema.apply(2).dataType.asInstanceOf[DecimalType]
assert(decimalType.precision == 38)
assert(decimalType.scale == 18)
}
}
}
}
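Case 3 in the Spark test exercises the new decimal rule end to end: after Case 2 the table column c is DECIMAL(38,18) (Spark's default mapping for BigDecimal), so the incoming DECIMAL(20,18) shares the column's scale and the merge keeps the larger precision, leaving the schema at DECIMAL(38,18). A core-level sketch of that single step, with values chosen to match the test rather than taken from the commit:

AtomicInteger highestFieldId = new AtomicInteger(10);

// Table column `c` after Case 2.
DataType tableType = new DecimalType(38, 18);
// Incoming DataFrame column: cast(c as decimal(20, 18)).
DataType incoming = new DecimalType(20, 18);

DecimalType mergedC = (DecimalType) SchemaMergingUtils.merge(tableType, incoming, highestFieldId, false);
// mergedC is DECIMAL(38, 18): same scale, max precision, so the table schema is unchanged.

At the DataFrame level this merge is triggered by the writer option write.merge-schema, as in the test above.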
