From 28e02a99efd64c273ee5762e900439856e367c86 Mon Sep 17 00:00:00 2001 From: Xiduo You Date: Tue, 10 Sep 2024 15:28:02 +0800 Subject: [PATCH] [core] Revisit merge schema to fix issues (#4094) --- .../paimon/schema/SchemaMergingUtils.java | 47 ++-- .../paimon/schema/SchemaMergingUtilsTest.java | 22 +- .../paimon/spark/sql/DataFrameWriteTest.scala | 201 ++++++++++-------- 3 files changed, 164 insertions(+), 106 deletions(-) diff --git a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java index 9591175e0a44..30004b53fcfb 100644 --- a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java +++ b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaMergingUtils.java @@ -39,18 +39,21 @@ public class SchemaMergingUtils { public static TableSchema mergeSchemas( - TableSchema currentTableSchema, RowType dataFields, boolean allowExplicitCast) { - if (currentTableSchema.logicalRowType().equals(dataFields)) { + TableSchema currentTableSchema, RowType targetType, boolean allowExplicitCast) { + RowType currentType = currentTableSchema.logicalRowType(); + if (currentType.equals(targetType)) { return currentTableSchema; } AtomicInteger highestFieldId = new AtomicInteger(currentTableSchema.highestFieldId()); RowType newRowType = - mergeSchemas( - currentTableSchema.logicalRowType(), - dataFields, - highestFieldId, - allowExplicitCast); + mergeSchemas(currentType, targetType, highestFieldId, allowExplicitCast); + if (newRowType == currentType) { + // It happens if the `targetType` only changes `nullability` but we always respect the + // current's. + return currentTableSchema; + } + return new TableSchema( currentTableSchema.id() + 1, newRowType.getFields(), @@ -86,7 +89,7 @@ public static DataType merge( DataType update0, AtomicInteger highestFieldId, boolean allowExplicitCast) { - // Here we try t0 merge the base0 and update0 without regard to the nullability, + // Here we try to merge the base0 and update0 without regard to the nullability, // and set the base0's nullability to the return's. 
DataType base = base0.copy(true); DataType update = update0.copy(true); @@ -134,10 +137,10 @@ public static DataType merge( .collect(Collectors.toList()); updatedFields.addAll(newFields); - return new RowType(base.isNullable(), updatedFields); + return new RowType(base0.isNullable(), updatedFields); } else if (base instanceof MapType && update instanceof MapType) { return new MapType( - base.isNullable(), + base0.isNullable(), merge( ((MapType) base).getKeyType(), ((MapType) update).getKeyType(), @@ -150,7 +153,7 @@ public static DataType merge( allowExplicitCast)); } else if (base instanceof ArrayType && update instanceof ArrayType) { return new ArrayType( - base.isNullable(), + base0.isNullable(), merge( ((ArrayType) base).getElementType(), ((ArrayType) update).getElementType(), @@ -158,19 +161,24 @@ public static DataType merge( allowExplicitCast)); } else if (base instanceof MultisetType && update instanceof MultisetType) { return new MultisetType( - base.isNullable(), + base0.isNullable(), merge( ((MultisetType) base).getElementType(), ((MultisetType) update).getElementType(), highestFieldId, allowExplicitCast)); } else if (base instanceof DecimalType && update instanceof DecimalType) { - if (base.equals(update)) { - return base0; + if (((DecimalType) base).getScale() == ((DecimalType) update).getScale()) { + return new DecimalType( + base0.isNullable(), + Math.max( + ((DecimalType) base).getPrecision(), + ((DecimalType) update).getPrecision()), + ((DecimalType) base).getScale()); } else { throw new UnsupportedOperationException( String.format( - "Failed to merge decimal types with different precision or scale: %s and %s", + "Failed to merge decimal types with different scale: %s and %s", base, update)); } } else if (supportsDataTypesCast(base, update, allowExplicitCast)) { @@ -212,10 +220,11 @@ public static DataType merge( private static boolean supportsDataTypesCast( DataType sourceType, DataType targetType, boolean allowExplicitCast) { - boolean canImplicitCast = DataTypeCasts.supportsImplicitCast(sourceType, targetType); - boolean canExplicitCast = - allowExplicitCast && DataTypeCasts.supportsExplicitCast(sourceType, targetType); - return canImplicitCast || canExplicitCast; + if (allowExplicitCast) { + return DataTypeCasts.supportsExplicitCast(sourceType, targetType); + } else { + return DataTypeCasts.supportsImplicitCast(sourceType, targetType); + } } private static DataField assignIdForNewField(DataField field, AtomicInteger highestFieldId) { diff --git a/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java b/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java index 53856a4c208e..8ad40852721a 100644 --- a/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/schema/SchemaMergingUtilsTest.java @@ -183,6 +183,7 @@ public void testMergeArrayTypes() { // the element types aren't same, but can be evolved safety. DataType t2 = new ArrayType(true, new BigIntType()); ArrayType r2 = (ArrayType) SchemaMergingUtils.merge(source, t2, highestFieldId, false); + assertThat(r2.isNullable()).isFalse(); assertThat(r2.getElementType() instanceof BigIntType).isTrue(); // the element types aren't same, and can't be evolved safety. @@ -192,6 +193,7 @@ public void testMergeArrayTypes() { // the value type of target's isn't same to the source's, but the source type can be cast to // the target type explicitly. 
ArrayType r3 = (ArrayType) SchemaMergingUtils.merge(source, t3, highestFieldId, true); + assertThat(r3.isNullable()).isFalse(); assertThat(r3.getElementType() instanceof SmallIntType).isTrue(); } @@ -211,6 +213,7 @@ public void testMergeMapTypes() { // the value type of target's isn't same to the source's, but can be evolved safety. DataType t2 = new MapType(new VarCharType(VarCharType.MAX_LENGTH), new DoubleType()); MapType r2 = (MapType) SchemaMergingUtils.merge(source, t2, highestFieldId, false); + assertThat(r2.isNullable()).isTrue(); assertThat(r2.getKeyType() instanceof VarCharType).isTrue(); assertThat(r2.getValueType() instanceof DoubleType).isTrue(); @@ -221,6 +224,7 @@ public void testMergeMapTypes() { // the value type of target's isn't same to the source's, but the source type can be cast to // the target type explicitly. MapType r3 = (MapType) SchemaMergingUtils.merge(source, t3, highestFieldId, true); + assertThat(r3.isNullable()).isTrue(); assertThat(r3.getKeyType() instanceof VarCharType).isTrue(); assertThat(r3.getValueType() instanceof SmallIntType).isTrue(); } @@ -242,6 +246,7 @@ public void testMergeMultisetTypes() { DataType t2 = new MultisetType(true, new BigIntType()); MultisetType r2 = (MultisetType) SchemaMergingUtils.merge(source, t2, highestFieldId, false); + assertThat(r2.isNullable()).isFalse(); assertThat(r2.getElementType() instanceof BigIntType).isTrue(); // the element types aren't same, and can't be evolved safety. @@ -251,6 +256,7 @@ public void testMergeMultisetTypes() { // the value type of target's isn't same to the source's, but the source type can be cast to // the target type explicitly. MultisetType r3 = (MultisetType) SchemaMergingUtils.merge(source, t3, highestFieldId, true); + assertThat(r3.isNullable()).isFalse(); assertThat(r3.getElementType() instanceof SmallIntType).isTrue(); } @@ -266,10 +272,24 @@ public void testMergeDecimalTypes() { assertThat(r1.getScale()).isEqualTo(DecimalType.DEFAULT_SCALE); DataType s2 = new DecimalType(5, 2); - DataType t2 = new DecimalType(7, 2); + DataType t2 = new DecimalType(7, 3); assertThatThrownBy(() -> SchemaMergingUtils.merge(s2, t2, highestFieldId, false)) .isInstanceOf(UnsupportedOperationException.class); + DataType s3 = new DecimalType(false, 5, 2); + DataType t3 = new DecimalType(7, 2); + DecimalType r3 = (DecimalType) SchemaMergingUtils.merge(s3, t3, highestFieldId, false); + assertThat(r3.isNullable()).isFalse(); + assertThat(r3.getPrecision()).isEqualTo(7); + assertThat(r3.getScale()).isEqualTo(2); + + DataType s4 = new DecimalType(7, 2); + DataType t4 = new DecimalType(5, 2); + DecimalType r4 = (DecimalType) SchemaMergingUtils.merge(s4, t4, highestFieldId, false); + assertThat(r4.isNullable()).isTrue(); + assertThat(r4.getPrecision()).isEqualTo(7); + assertThat(r4.getScale()).isEqualTo(2); + // DecimalType -> Other Numeric Type DataType dcmSource = new DecimalType(); DataType iTarget = new IntType(); diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala index ca3ba8797be6..3f6e81da018c 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/DataFrameWriteTest.scala @@ -21,6 +21,7 @@ package org.apache.paimon.spark.sql import org.apache.paimon.spark.PaimonSparkTestBase import org.apache.spark.sql.Row 
+import org.apache.spark.sql.types.DecimalType import org.junit.jupiter.api.Assertions import java.sql.{Date, Timestamp} @@ -177,92 +178,120 @@ class DataFrameWriteTest extends PaimonSparkTestBase { } } - withPk.foreach { - hasPk => - bucketModes.foreach { - bucket => - test(s"Schema evolution: write data into Paimon: $hasPk, bucket: $bucket") { - val _spark = spark - import _spark.implicits._ - - val prop = if (hasPk) { - s"'primary-key'='a', 'bucket' = '$bucket' " - } else if (bucket != -1) { - s"'bucket-key'='a', 'bucket' = '$bucket' " - } else { - "'write-only'='true'" - } - - spark.sql(s""" - |CREATE TABLE T (a INT, b STRING) - |TBLPROPERTIES ($prop) - |""".stripMargin) - - val paimonTable = loadTable("T") - val location = paimonTable.location().toString - - val df1 = Seq((1, "a"), (2, "b")).toDF("a", "b") - df1.write.format("paimon").mode("append").save(location) - checkAnswer( - spark.sql("SELECT * FROM T ORDER BY a, b"), - Row(1, "a") :: Row(2, "b") :: Nil) - - // Case 1: two additional fields - val df2 = Seq((1, "a2", 123L, Map("k" -> 11.1)), (3, "c", 345L, Map("k" -> 33.3))) - .toDF("a", "b", "c", "d") - df2.write - .format("paimon") - .mode("append") - .option("write.merge-schema", "true") - .save(location) - val expected2 = if (hasPk) { - Row(1, "a2", 123L, Map("k" -> 11.1)) :: - Row(2, "b", null, null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil - } else { - Row(1, "a", null, null) :: Row(1, "a2", 123L, Map("k" -> 11.1)) :: Row( - 2, - "b", - null, - null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil - } - checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected2) - - // Case 2: two fields with the evolved types: Int -> Long, Long -> Decimal - val df3 = Seq( - (2L, "b2", BigDecimal.decimal(234), Map("k" -> 22.2)), - (4L, "d", BigDecimal.decimal(456), Map("k" -> 44.4))).toDF("a", "b", "c", "d") - df3.write - .format("paimon") - .mode("append") - .option("write.merge-schema", "true") - .save(location) - val expected3 = if (hasPk) { - Row(1L, "a2", BigDecimal.decimal(123), Map("k" -> 11.1)) :: Row( - 2L, - "b2", - BigDecimal.decimal(234), - Map("k" -> 22.2)) :: Row(3L, "c", BigDecimal.decimal(345), Map("k" -> 33.3)) :: Row( - 4L, - "d", - BigDecimal.decimal(456), - Map("k" -> 44.4)) :: Nil - } else { - Row(1L, "a", null, null) :: Row( - 1L, - "a2", - BigDecimal.decimal(123), - Map("k" -> 11.1)) :: Row(2L, "b", null, null) :: Row( - 2L, - "b2", - BigDecimal.decimal(234), - Map("k" -> 22.2)) :: Row(3L, "c", BigDecimal.decimal(345), Map("k" -> 33.3)) :: Row( - 4L, - "d", - BigDecimal.decimal(456), - Map("k" -> 44.4)) :: Nil - } - checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected3) - + fileFormats.foreach { + format => + withPk.foreach { + hasPk => + bucketModes.foreach { + bucket => + test( + s"Schema evolution: write data into Paimon: $hasPk, bucket: $bucket, format: $format") { + val _spark = spark + import _spark.implicits._ + + val prop = if (hasPk) { + s"'primary-key'='a', 'bucket' = '$bucket', 'file.format' = '$format'" + } else if (bucket != -1) { + s"'bucket-key'='a', 'bucket' = '$bucket', 'file.format' = '$format'" + } else { + s"'write-only'='true', 'file.format' = '$format'" + } + + spark.sql(s""" + |CREATE TABLE T (a INT, b STRING) + |TBLPROPERTIES ($prop) + |""".stripMargin) + + val paimonTable = loadTable("T") + val location = paimonTable.location().toString + + val df1 = Seq((1, "a"), (2, "b")).toDF("a", "b") + df1.write.format("paimon").mode("append").save(location) + checkAnswer( + spark.sql("SELECT * FROM T ORDER BY a, b"), + Row(1, "a") 
:: Row(2, "b") :: Nil) + + // Case 1: two additional fields + val df2 = Seq((1, "a2", 123L, Map("k" -> 11.1)), (3, "c", 345L, Map("k" -> 33.3))) + .toDF("a", "b", "c", "d") + df2.write + .format("paimon") + .mode("append") + .option("write.merge-schema", "true") + .save(location) + val expected2 = if (hasPk) { + Row(1, "a2", 123L, Map("k" -> 11.1)) :: + Row(2, "b", null, null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil + } else { + Row(1, "a", null, null) :: Row(1, "a2", 123L, Map("k" -> 11.1)) :: Row( + 2, + "b", + null, + null) :: Row(3, "c", 345L, Map("k" -> 33.3)) :: Nil + } + checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected2) + + // Case 2: two fields with the evolved types: Int -> Long, Long -> Decimal + val df3 = Seq( + (2L, "b2", BigDecimal.decimal(234), Map("k" -> 22.2)), + (4L, "d", BigDecimal.decimal(456), Map("k" -> 44.4))).toDF("a", "b", "c", "d") + df3.write + .format("paimon") + .mode("append") + .option("write.merge-schema", "true") + .save(location) + val expected3 = if (hasPk) { + Row(1L, "a2", BigDecimal.decimal(123), Map("k" -> 11.1)) :: Row( + 2L, + "b2", + BigDecimal.decimal(234), + Map("k" -> 22.2)) :: Row( + 3L, + "c", + BigDecimal.decimal(345), + Map("k" -> 33.3)) :: Row( + 4L, + "d", + BigDecimal.decimal(456), + Map("k" -> 44.4)) :: Nil + } else { + Row(1L, "a", null, null) :: Row( + 1L, + "a2", + BigDecimal.decimal(123), + Map("k" -> 11.1)) :: Row(2L, "b", null, null) :: Row( + 2L, + "b2", + BigDecimal.decimal(234), + Map("k" -> 22.2)) :: Row( + 3L, + "c", + BigDecimal.decimal(345), + Map("k" -> 33.3)) :: Row( + 4L, + "d", + BigDecimal.decimal(456), + Map("k" -> 44.4)) :: Nil + } + checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected3) + + // Case 3: insert Decimal(20,18) to Decimal(38,18) + val df4 = Seq((99L, "df4", BigDecimal.decimal(4.0), Map("4" -> 4.1))) + .toDF("a", "b", "c", "d") + .selectExpr("a", "b", "cast(c as decimal(20,18)) as c", "d") + df4.write + .format("paimon") + .mode("append") + .option("write.merge-schema", "true") + .save(location) + val expected4 = + expected3 ++ Seq(Row(99L, "df4", BigDecimal.decimal(4.0), Map("4" -> 4.1))) + checkAnswer(spark.sql("SELECT * FROM T ORDER BY a, b"), expected4) + val decimalType = + spark.table("T").schema.apply(2).dataType.asInstanceOf[DecimalType] + assert(decimalType.precision == 38) + assert(decimalType.scale == 18) + } } } }