From 8b1a37cabd09237888e61240256c7e4650e9d483 Mon Sep 17 00:00:00 2001 From: zouxxyy Date: Tue, 17 Sep 2024 11:26:31 +0800 Subject: [PATCH] 1 --- .../org/apache/paimon/types/DataField.java | 4 + .../org/apache/paimon/utils/Projection.java | 90 +++++++++++++------ .../paimon/table/source/ReadBuilder.java | 2 +- .../paimon/table/source/ReadBuilderImpl.java | 3 +- .../reader/ParquetSplitReaderUtil.java | 6 +- .../apache/paimon/spark/SparkTypeUtils.java | 31 ++++++- .../apache/paimon/spark/PaimonBaseScan.scala | 20 ++--- .../paimon/spark/PaimonBaseScanBuilder.scala | 7 +- .../paimon/spark/SparkInternalRowTest.java | 2 +- ...kTypeTest.java => SparkTypeUtilsTest.java} | 81 ++++++++++++++++- .../paimon/spark/sql/PaimonQueryTest.scala | 27 ++++++ 11 files changed, 221 insertions(+), 52 deletions(-) rename paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/{SparkTypeTest.java => SparkTypeUtilsTest.java} (54%) diff --git a/paimon-common/src/main/java/org/apache/paimon/types/DataField.java b/paimon-common/src/main/java/org/apache/paimon/types/DataField.java index baffbcf3f997..25983f6d2d40 100644 --- a/paimon-common/src/main/java/org/apache/paimon/types/DataField.java +++ b/paimon-common/src/main/java/org/apache/paimon/types/DataField.java @@ -84,6 +84,10 @@ public DataField newName(String newName) { return new DataField(id, newName, type, description); } + public DataField newType(DataType newType) { + return new DataField(id, name, newType, description); + } + public DataField newDescription(String newDescription) { return new DataField(id, name, type, newDescription); } diff --git a/paimon-common/src/main/java/org/apache/paimon/utils/Projection.java b/paimon-common/src/main/java/org/apache/paimon/utils/Projection.java index bfdca8e71a1f..b2d04176002c 100644 --- a/paimon-common/src/main/java/org/apache/paimon/utils/Projection.java +++ b/paimon-common/src/main/java/org/apache/paimon/utils/Projection.java @@ -27,14 +27,13 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.ListIterator; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; -import static org.apache.paimon.types.DataTypeRoot.ROW; - /** * {@link Projection} represents a list of (possibly nested) indexes that can be used to project * data types. A row projection includes both reducing the accessible fields and reordering them. @@ -238,28 +237,18 @@ private static class NestedProjection extends Projection { @Override public RowType project(RowType rowType) { - final List updatedFields = new ArrayList<>(); - Set nameDomain = new HashSet<>(); - int duplicateCount = 0; - for (int[] indexPath : this.projection) { - DataField field = rowType.getFields().get(indexPath[0]); - StringBuilder builder = - new StringBuilder(rowType.getFieldNames().get(indexPath[0])); - for (int index = 1; index < indexPath.length; index++) { - Preconditions.checkArgument( - field.type().getTypeRoot() == ROW, "Row data type expected."); - RowType rowtype = ((RowType) field.type()); - builder.append("_").append(rowtype.getFieldNames().get(indexPath[index])); - field = rowtype.getFields().get(indexPath[index]); - } - String path = builder.toString(); - while (nameDomain.contains(path)) { - path = builder.append("_$").append(duplicateCount++).toString(); + ProjectedDataTypeBuilder builder = new ProjectedDataTypeBuilder(rowType); + for (int[] path : projection) { + ProjectedDataTypeBuilder current = builder; + for (int i = 0; i < path.length; i++) { + current.projectField(path[i]); + if (i == path.length - 1) { + current.fieldBuilder(path[i]).project(); + } + current = current.fieldBuilder(path[i]); } - updatedFields.add(field.newName(path)); - nameDomain.add(path); } - return new RowType(rowType.isNullable(), updatedFields); + return (RowType) builder.build(); } @Override @@ -321,10 +310,7 @@ public Projection complement(int fieldsNumber) { @Override public int[] toTopLevelIndexes() { - if (isNested()) { - throw new IllegalStateException( - "Cannot convert a nested projection to a top level projection"); - } + // todo: fix it usage return Arrays.stream(projection).mapToInt(arr -> arr[0]).toArray(); } @@ -416,4 +402,56 @@ public int[][] toNestedIndexes() { return Arrays.stream(projection).mapToObj(i -> new int[] {i}).toArray(int[][]::new); } } + + private static class ProjectedDataTypeBuilder { + private final DataType dataType; + private boolean projected = false; + private final Set projectedFieldIds = new HashSet<>(); + private final LinkedList fieldBuilders = new LinkedList<>(); + + public ProjectedDataTypeBuilder(DataType dataType) { + this.dataType = dataType; + if (dataType instanceof RowType) { + for (DataField field : ((RowType) dataType).getFields()) { + fieldBuilders.add(new ProjectedDataTypeBuilder(field.type())); + } + } + } + + public ProjectedDataTypeBuilder project() { + this.projected = true; + return this; + } + + public ProjectedDataTypeBuilder projectField(int fieldId) { + if (!projected) { + this.projectedFieldIds.add(fieldId); + } + return this; + } + + public ProjectedDataTypeBuilder fieldBuilder(int fieldId) { + return fieldBuilders.get(fieldId); + } + + public DataType build() { + if (projected) { + return dataType.copy(); + } + + if (!fieldBuilders.isEmpty() && !projectedFieldIds.isEmpty()) { + List oldFields = ((RowType) dataType).getFields(); + List fields = new ArrayList<>(fieldBuilders.size()); + for (int i = 0; i < fieldBuilders.size(); i++) { + if (projectedFieldIds.contains(i)) { + DataType newType = fieldBuilders.get(i).build(); + fields.add(oldFields.get(i).newType(newType)); + } + } + return new RowType(dataType.isNullable(), fields); + } else { + throw new RuntimeException(); + } + } + } } diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilder.java b/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilder.java index b7207927ca68..7e4b0cc685d5 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilder.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilder.java @@ -113,7 +113,7 @@ default ReadBuilder withFilter(List predicates) { /** * Apply projection to the reader. * - *

NOTE: Nested row projection is currently not supported. + *

todo: update it. */ default ReadBuilder withProjection(int[] projection) { if (projection == null) { diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java b/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java index e9b83340a881..4b76b9748538 100644 --- a/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java +++ b/paimon-core/src/main/java/org/apache/paimon/table/source/ReadBuilderImpl.java @@ -24,7 +24,6 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.utils.Filter; import org.apache.paimon.utils.Projection; -import org.apache.paimon.utils.TypeUtils; import java.util.Arrays; import java.util.Map; @@ -65,7 +64,7 @@ public RowType readType() { if (projection == null) { return table.rowType(); } - return TypeUtils.project(table.rowType(), Projection.of(projection).toTopLevelIndexes()); + return Projection.of(projection).project(table.rowType()); } @Override diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java index 90abaa992c17..860ec54fa88b 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetSplitReaderUtil.java @@ -370,12 +370,12 @@ private static List getAllColumnDescriptorByType( } public static List buildFieldsList( - List childrens, List fieldNames, MessageColumnIO columnIO) { + List children, List fieldNames, MessageColumnIO columnIO) { List list = new ArrayList<>(); - for (int i = 0; i < childrens.size(); i++) { + for (int i = 0; i < children.size(); i++) { list.add( constructField( - childrens.get(i), lookupColumnByName(columnIO, fieldNames.get(i)))); + children.get(i), lookupColumnByName(columnIO, fieldNames.get(i)))); } return list; } diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java index 08fb2de32aa6..f8132b8d789f 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkTypeUtils.java @@ -50,6 +50,7 @@ import org.apache.spark.sql.types.UserDefinedType; import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; @@ -183,10 +184,6 @@ public DataType visit(MapType mapType) { mapType.getValueType().isNullable()); } - /** - * For simplicity, as a temporary solution, we directly convert the non-null attribute to - * nullable on the Spark side. - */ @Override public DataType visit(RowType rowType) { List fields = new ArrayList<>(rowType.getFieldCount()); @@ -333,4 +330,30 @@ public org.apache.paimon.types.DataType atomic(DataType atomic) { "Not a supported type: " + atomic.catalogString()); } } + + public static int[][] populateProjection(StructType structType, RowType type) { + LinkedList projectionList = new LinkedList<>(); + populateProjection(structType, type, projectionList, new LinkedList<>()); + return projectionList.toArray(new int[0][]); + } + + private static void populateProjection( + StructType structType, + RowType rowType, + LinkedList projectionList, + LinkedList currentPath) { + for (StructField field : structType.fields()) { + currentPath.add(rowType.getFieldIndex(field.name())); + if (field.dataType() instanceof StructType) { + populateProjection( + (StructType) field.dataType(), + (RowType) rowType.getField(field.name()).type(), + projectionList, + currentPath); + } else { + projectionList.add(currentPath.stream().mapToInt(Integer::intValue).toArray()); + } + currentPath.removeLast(); + } + } } diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala index 2188879c7ac3..103c232bd564 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala @@ -55,13 +55,12 @@ abstract class PaimonBaseScan( private lazy val tableSchema = SparkTypeUtils.fromPaimonRowType(tableRowType) private[paimon] val (requiredTableFields, metadataFields) = { - val nameToField = tableSchema.map(field => (field.name, field)).toMap - val _tableFields = requiredSchema.flatMap(field => nameToField.get(field.name)) - val _metadataFields = - requiredSchema - .filterNot(field => tableSchema.fieldNames.contains(field.name)) - .filter(field => PaimonMetadataColumn.SUPPORTED_METADATA_COLUMNS.contains(field.name)) - (_tableFields, _metadataFields) + assert( + requiredSchema.fields.forall( + field => + tableRowType.containsField(field.name) || + PaimonMetadataColumn.SUPPORTED_METADATA_COLUMNS.contains(field.name))) + requiredSchema.fields.partition(field => tableRowType.containsField(field.name)) } protected var runtimeFilters: Array[Filter] = Array.empty @@ -82,9 +81,8 @@ abstract class PaimonBaseScan( lazy val readBuilder: ReadBuilder = { val _readBuilder = table.newReadBuilder() - val projection = - requiredTableFields.map(field => tableSchema.fieldNames.indexOf(field.name)).toArray - _readBuilder.withProjection(projection) + _readBuilder.withProjection( + SparkTypeUtils.populateProjection(StructType(requiredTableFields), tableRowType)) if (filters.nonEmpty) { val pushedPredicate = PredicateBuilder.and(filters: _*) _readBuilder.withFilter(pushedPredicate) @@ -114,7 +112,7 @@ abstract class PaimonBaseScan( } override def readSchema(): StructType = { - StructType(requiredTableFields ++ metadataFields) + requiredSchema } override def toBatch: Batch = { diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala index 0efe14552afe..c91e1cc9f401 100644 --- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala +++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScanBuilder.scala @@ -34,15 +34,16 @@ abstract class PaimonBaseScanBuilder(table: Table) with SupportsPushDownRequiredColumns with Logging { - protected var requiredSchema: StructType = SparkTypeUtils.fromPaimonRowType(table.rowType()) + private var prunedSchema: Option[StructType] = None - protected var pushed: Array[(Filter, Predicate)] = Array.empty + private var pushed: Array[(Filter, Predicate)] = Array.empty protected var reservedFilters: Array[Filter] = Array.empty protected var pushDownLimit: Option[Int] = None override def build(): Scan = { + val requiredSchema = prunedSchema.getOrElse(SparkTypeUtils.fromPaimonRowType(table.rowType)) PaimonScan(table, requiredSchema, pushed.map(_._2), reservedFilters, pushDownLimit) } @@ -87,6 +88,6 @@ abstract class PaimonBaseScanBuilder(table: Table) } override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + this.prunedSchema = Some(requiredSchema) } } diff --git a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkInternalRowTest.java b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkInternalRowTest.java index 9af886d8369f..c5c729a466d9 100644 --- a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkInternalRowTest.java +++ b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkInternalRowTest.java @@ -45,7 +45,7 @@ import scala.collection.JavaConverters; import static org.apache.paimon.data.BinaryString.fromString; -import static org.apache.paimon.spark.SparkTypeTest.ALL_TYPES; +import static org.apache.paimon.spark.SparkTypeUtilsTest.ALL_TYPES; import static org.assertj.core.api.Assertions.assertThat; /** Test for {@link SparkInternalRow}. */ diff --git a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkTypeTest.java b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkTypeUtilsTest.java similarity index 54% rename from paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkTypeTest.java rename to paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkTypeUtilsTest.java index fdc7558fd5f4..fe5a749504cb 100644 --- a/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkTypeTest.java +++ b/paimon-spark/paimon-spark-common/src/test/java/org/apache/paimon/spark/SparkTypeUtilsTest.java @@ -21,6 +21,8 @@ import org.apache.paimon.types.DataTypes; import org.apache.paimon.types.RowType; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Test; @@ -31,7 +33,7 @@ import static org.assertj.core.api.Assertions.assertThat; /** Test for {@link SparkTypeUtils}. */ -public class SparkTypeTest { +public class SparkTypeUtilsTest { public static final RowType ALL_TYPES = RowType.builder( @@ -107,4 +109,81 @@ public void testAllTypes() { assertThat(toPaimonType(sparkType)).isEqualTo(ALL_TYPES); } + + @Test + public void testPopulateProjection() { + RowType rowType = + DataTypes.ROW( + DataTypes.FIELD(0, "f0", DataTypes.INT()), + DataTypes.FIELD( + 1, + "f1", + DataTypes.ROW( + DataTypes.FIELD(0, "f0", DataTypes.INT()), + DataTypes.FIELD(1, "f1", DataTypes.INT()))), + DataTypes.FIELD( + 2, + "f2", + DataTypes.ROW( + DataTypes.FIELD( + 0, + "f0", + DataTypes.ROW( + DataTypes.FIELD(0, "f0", DataTypes.INT()), + DataTypes.FIELD( + 1, "f1", DataTypes.INT())))))); + + StructType structType = + new StructType( + new StructField[] { + new StructField( + "f0", + org.apache.spark.sql.types.DataTypes.IntegerType, + false, + Metadata.empty()), + new StructField( + "f1", + new StructType( + new StructField[] { + new StructField( + "f0", + org.apache.spark.sql.types.DataTypes + .IntegerType, + false, + Metadata.empty()), + new StructField( + "f1", + org.apache.spark.sql.types.DataTypes + .IntegerType, + false, + Metadata.empty()) + }), + false, + Metadata.empty()), + new StructField( + "f2", + new StructType( + new StructField[] { + new StructField( + "f0", + new StructType( + new StructField[] { + new StructField( + "f1", + org.apache.spark.sql + .types.DataTypes + .IntegerType, + false, + Metadata.empty()) + }), + false, + Metadata.empty()) + }), + false, + Metadata.empty()) + }); + + assertThat(SparkTypeUtils.populateProjection(structType, rowType)) + .isEqualTo(new int[][] {{0}, {1, 0}, {1, 1}, {2, 0, 1}}); + } } diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala index 70296570181d..d1169378bcbc 100644 --- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala +++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonQueryTest.scala @@ -175,6 +175,33 @@ class PaimonQueryTest extends PaimonSparkTestBase { } } + test("Paimon Query: query nested cols") { + withTable("students") { + sql(""" + |CREATE TABLE students ( + | name STRING, + | age INT, + | course STRUCT + |) USING paimon; + |""".stripMargin) + + sql(""" + |INSERT INTO students VALUES + |('Alice', 20, STRUCT('Math', 85.0)), + |('Bob', 22, STRUCT('Biology', 92.0)), + |('Cathy', 21, STRUCT('History', 95.0)); + |""".stripMargin) + + sql(""" + |SELECT + | name, + | course.course_name + |FROM + | students; + |""".stripMargin).show() + } + } + private def getAllFiles( tableName: String, partitions: Seq[String],