From 19374e27dfc95ffd53f3ad80f692450aececf728 Mon Sep 17 00:00:00 2001
From: jackierwzhang <67607237+jackierwzhang@users.noreply.github.com>
Date: Tue, 24 Sep 2024 23:06:05 +0800
Subject: [PATCH] Generalize schema utils to support non-struct root-level
 access (#3716)

#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

Generalize some schema utils so that ordinal paths can index into non-struct
root-level data structures, e.g. a top-level array or map.
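As a sketch of what this enables (not part of the patch; it mirrors the new
unit tests and assumes the existing `ARRAY_ELEMENT_INDEX = 0` constant):

```scala
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.apache.spark.sql.types._

// A non-struct root: an array of structs.
val schema = ArrayType(new StructType().add("a", IntegerType).add("b", StringType))

// Previously addColumn required a StructType root. Now the ordinal path may start
// at the array itself: Seq(0, 1) = array element, then insertion ordinal 1 inside
// the element struct.
val updated = SchemaUtils.addColumn(schema, StructField("x", LongType), Seq(0, 1))
// => ArrayType(StructType(a, x, b))
```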
## How was this patch tested?

New unit tests.

## Does this PR introduce _any_ user-facing changes?

No.
---
 .../apache/spark/sql/delta/DeltaErrors.scala  |  29 +++--
 .../sql/delta/schema/SchemaMergingUtils.scala |   6 +-
 .../spark/sql/delta/schema/SchemaUtils.scala  | 119 ++++++++++--------
 .../sql/delta/schema/SchemaUtilsSuite.scala   |  91 ++++++++++++++
 4 files changed, 182 insertions(+), 63 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
index 9d8d242b20..58d7105bae 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaErrors.scala
@@ -1634,14 +1634,14 @@ trait DeltaErrorsBase
       messageParameters = Array(option, operation))
   }
 
-  def foundMapTypeColumnException(key: String, value: String, schema: StructType): Throwable = {
+  def foundMapTypeColumnException(key: String, value: String, schema: DataType): Throwable = {
     new DeltaAnalysisException(
       errorClass = "DELTA_FOUND_MAP_TYPE_COLUMN",
-      messageParameters = Array(key, value, schema.treeString)
+      messageParameters = Array(key, value, dataTypeToString(schema))
     )
   }
-  def columnNotInSchemaException(column: String, schema: StructType): Throwable = {
-    nonExistentColumnInSchema(column, schema.treeString)
+  def columnNotInSchemaException(column: String, schema: DataType): Throwable = {
+    nonExistentColumnInSchema(column, dataTypeToString(schema))
   }
 
   def metadataAbsentException(): Throwable = {
@@ -2690,10 +2690,14 @@ trait DeltaErrorsBase
   def incorrectArrayAccessByName(
       rightName: String,
       wrongName: String,
-      schema: StructType): Throwable = {
+      schema: DataType): Throwable = {
     new DeltaAnalysisException(
       errorClass = "DELTA_INCORRECT_ARRAY_ACCESS_BY_NAME",
-      messageParameters = Array(rightName, wrongName, schema.treeString)
+      messageParameters = Array(
+        rightName,
+        wrongName,
+        dataTypeToString(schema)
+      )
     )
   }
 
@@ -2701,14 +2705,14 @@ trait DeltaErrorsBase
       columnPath: String,
       other: DataType,
       column: Seq[String],
-      schema: StructType): Throwable = {
+      schema: DataType): Throwable = {
     new DeltaAnalysisException(
       errorClass = "DELTA_COLUMN_PATH_NOT_NESTED",
       messageParameters = Array(
         s"$columnPath",
         s"$other",
         s"${SchemaUtils.prettyFieldName(column)}",
-        schema.treeString
+        dataTypeToString(schema)
       )
     )
   }
@@ -3445,11 +3449,11 @@ trait DeltaErrorsBase
   }
 
   def errorFindingColumnPosition(
-      columnPath: Seq[String], schema: StructType, extraErrMsg: String): Throwable = {
+      columnPath: Seq[String], schema: DataType, extraErrMsg: String): Throwable = {
     new DeltaAnalysisException(
       errorClass = "_LEGACY_ERROR_TEMP_DELTA_0008",
       messageParameters = Array(
-        UnresolvedAttribute(columnPath).name, schema.treeString, extraErrMsg))
+        UnresolvedAttribute(columnPath).name, dataTypeToString(schema), extraErrMsg))
   }
 
   def alterTableClusterByOnPartitionedTableException(): Throwable = {
@@ -3481,6 +3485,11 @@ trait DeltaErrorsBase
       errorClass = "DELTA_UNSUPPORTED_WRITES_WITHOUT_COORDINATOR",
       messageParameters = Array(coordinatorName))
   }
+
+  private def dataTypeToString(dt: DataType): String = dt match {
+    case s: StructType => s.treeString
+    case other => other.simpleString
+  }
 }
 
 object DeltaErrors extends DeltaErrorsBase
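A note on the `dataTypeToString` helper introduced above: struct roots keep the
verbose `treeString` rendering in error messages, while non-struct roots fall
back to Spark's one-line `simpleString`. A minimal illustration using standard
Spark APIs (not part of the patch):

```scala
import org.apache.spark.sql.types._

new StructType().add("a", IntegerType).treeString
// root
//  |-- a: integer (nullable = true)

MapType(StringType, IntegerType).simpleString
// map<string,int>
```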
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala
index cfd84e9b0f..fd7172603c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala
@@ -296,9 +296,9 @@ object SchemaMergingUtils {
    * @param tf function to apply.
    * @return the transformed schema.
    */
-  def transformColumns(
-      schema: StructType)(
-      tf: (Seq[String], StructField, Resolver) => StructField): StructType = {
+  def transformColumns[T <: DataType](
+      schema: T)(
+      tf: (Seq[String], StructField, Resolver) => StructField): T = {
     def transform[E <: DataType](path: Seq[String], dt: E): E = {
       val newDt = dt match {
         case StructType(fields) =>
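The new type parameter threads the concrete root type through
`transformColumns`, so transforming a top-level map yields a `MapType` back
instead of requiring a struct root. A sketch that mirrors the new
"transform top level map type" test:

```scala
import org.apache.spark.sql.delta.schema.SchemaMergingUtils
import org.apache.spark.sql.types._

val mt = MapType(
  new StructType().add("k1", IntegerType),
  new StructType().add("v1", IntegerType))

// Both the key struct and the value struct are visited; the result is again a MapType.
val renamed: MapType = SchemaMergingUtils.transformColumns(mt) {
  case (_, field, _) => field.copy(name = field.name + "_1", dataType = StringType)
}
// => MapType(StructType(k1_1: string), StructType(v1_1: string))
```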
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
index b346802caf..0a018e3ca5 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala
@@ -59,7 +59,7 @@ object SchemaUtils extends DeltaLogging {
    * defines whether we should recurse into ArrayType and MapType.
    */
   def filterRecursively(
-      schema: StructType,
+      schema: DataType,
       checkComplexTypes: Boolean)(f: StructField => Boolean): Seq[(Seq[String], StructField)] = {
     def recurseIntoComplexTypes(
         complexType: DataType,
@@ -699,7 +699,7 @@ def normalizeColumnNamesInDataType(
    */
   def findColumnPosition(
       column: Seq[String],
-      schema: StructType,
+      schema: DataType,
       resolver: Resolver = DELTA_COL_RESOLVER): Seq[Int] = {
     def findRecursively(
         searchPath: Seq[String],
@@ -803,7 +803,7 @@ def normalizeColumnNamesInDataType(
    * @param position A list of ordinals (0-based) representing the path to the nested field in
    *                 `parent`.
    */
-  def getNestedTypeFromPosition(schema: StructType, position: Seq[Int]): DataType =
+  def getNestedTypeFromPosition(schema: DataType, position: Seq[Int]): DataType =
     getNestedFieldFromPosition(StructField("schema", schema), position).dataType
 
   /**
@@ -814,7 +814,34 @@ def normalizeColumnNamesInDataType(
   }
 
   /**
-   * Add `column` to the specified `position` in `schema`.
+   * Add `column` to `parent` at the given `position`, recursing into maps and arrays.
+   * @param parent The parent data type to add the column to.
+   * @param column The column to add.
+   * @param position The ordinal path (0-based) at which to add the column.
+   */
+  def addColumn[T <: DataType](parent: T, column: StructField, position: Seq[Int]): T = {
+    if (position.isEmpty) {
+      throw DeltaErrors.addColumnParentNotStructException(column, parent)
+    }
+    parent match {
+      case struct: StructType =>
+        addColumnToStruct(struct, column, position).asInstanceOf[T]
+      case map: MapType if position.head == MAP_KEY_INDEX =>
+        map.copy(keyType = addColumn(map.keyType, column, position.tail)).asInstanceOf[T]
+      case map: MapType if position.head == MAP_VALUE_INDEX =>
+        map.copy(valueType = addColumn(map.valueType, column, position.tail)).asInstanceOf[T]
+      case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
+        array.copy(elementType =
+          addColumn(array.elementType, column, position.tail)).asInstanceOf[T]
+      case _: ArrayType =>
+        throw DeltaErrors.incorrectArrayAccess()
+      case other =>
+        throw DeltaErrors.addColumnParentNotStructException(column, other)
+    }
+  }
+
+  /**
+   * Add `column` to the specified `position` in a struct `schema`.
    * @param position A Seq of ordinals on where this column should go. It is a Seq to denote
    *                 positions in nested columns (0-based). For example:
   *
   *                 <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c3>>, column c2, position Seq(2, 1)
   *                 will return
   *                 result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c2,c3>>
   */
-  def addColumn(schema: StructType, column: StructField, position: Seq[Int]): StructType = {
-    def addColumnInChild(parent: DataType, column: StructField, position: Seq[Int]): DataType = {
-      if (position.isEmpty) {
-        throw DeltaErrors.addColumnParentNotStructException(column, parent)
-      }
-      parent match {
-        case struct: StructType =>
-          addColumn(struct, column, position)
-        case map: MapType if position.head == MAP_KEY_INDEX =>
-          map.copy(keyType = addColumnInChild(map.keyType, column, position.tail))
-        case map: MapType if position.head == MAP_VALUE_INDEX =>
-          map.copy(valueType = addColumnInChild(map.valueType, column, position.tail))
-        case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
-          array.copy(elementType = addColumnInChild(array.elementType, column, position.tail))
-        case _: ArrayType =>
-          throw DeltaErrors.incorrectArrayAccess()
-        case other =>
-          throw DeltaErrors.addColumnParentNotStructException(column, other)
-      }
-    }
+  private def addColumnToStruct(
+      schema: StructType,
+      column: StructField,
+      position: Seq[Int]): StructType = {
     // If the proposed new column includes a default value, return a specific "not supported" error.
     // The rationale is that such operations require the data source scan operator to implement
     // support for filling in the specified default value when the corresponding field is not
@@ -877,13 +888,42 @@ def normalizeColumnNamesInDataType(
       if (!column.nullable && field.nullable) {
         throw DeltaErrors.nullableParentWithNotNullNestedField
       }
-      val mid = field.copy(dataType = addColumnInChild(field.dataType, column, position.tail))
+      val mid = field.copy(dataType = addColumn(field.dataType, column, position.tail))
       StructType(pre ++ Seq(mid) ++ post.tail)
     } else {
       StructType(pre ++ Seq(column) ++ post)
     }
   }
 
+  /**
+   * Drop the column at `position` from `parent` and return it, recursing into maps and arrays.
+   * @param parent The parent data type to drop the column from.
+   * @param position The ordinal path (0-based) of the column to drop.
+   */
+  def dropColumn[T <: DataType](parent: T, position: Seq[Int]): (T, StructField) = {
+    if (position.isEmpty) {
+      throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(parent)
+    }
+    parent match {
+      case struct: StructType =>
+        val (t, s) = dropColumnInStruct(struct, position)
+        (t.asInstanceOf[T], s)
+      case map: MapType if position.head == MAP_KEY_INDEX =>
+        val (newKeyType, droppedColumn) = dropColumn(map.keyType, position.tail)
+        map.copy(keyType = newKeyType).asInstanceOf[T] -> droppedColumn
+      case map: MapType if position.head == MAP_VALUE_INDEX =>
+        val (newValueType, droppedColumn) = dropColumn(map.valueType, position.tail)
+        map.copy(valueType = newValueType).asInstanceOf[T] -> droppedColumn
+      case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
+        val (newElementType, droppedColumn) = dropColumn(array.elementType, position.tail)
+        array.copy(elementType = newElementType).asInstanceOf[T] -> droppedColumn
+      case _: ArrayType =>
+        throw DeltaErrors.incorrectArrayAccess()
+      case other =>
+        throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(other)
+    }
+  }
+
   /**
    * Drop from the specified `position` in `schema` and return with the original column.
    * @param position A Seq of ordinals on where this column should go. It is a Seq to denote
    *                 positions in nested columns (0-based). For example:
   *
   *                 <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c2,c3>>, position Seq(2, 1)
   *                 will return
   *                 result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c3>> and column c2
   */
-  def dropColumn(schema: StructType, position: Seq[Int]): (StructType, StructField) = {
-    def dropColumnInChild(parent: DataType, position: Seq[Int]): (DataType, StructField) = {
-      if (position.isEmpty) {
-        throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(parent)
-      }
-      parent match {
-        case struct: StructType =>
-          dropColumn(struct, position)
-        case map: MapType if position.head == MAP_KEY_INDEX =>
-          val (newKeyType, droppedColumn) = dropColumnInChild(map.keyType, position.tail)
-          map.copy(keyType = newKeyType) -> droppedColumn
-        case map: MapType if position.head == MAP_VALUE_INDEX =>
-          val (newValueType, droppedColumn) = dropColumnInChild(map.valueType, position.tail)
-          map.copy(valueType = newValueType) -> droppedColumn
-        case array: ArrayType if position.head == ARRAY_ELEMENT_INDEX =>
-          val (newElementType, droppedColumn) = dropColumnInChild(array.elementType, position.tail)
-          array.copy(elementType = newElementType) -> droppedColumn
-        case _: ArrayType =>
-          throw DeltaErrors.incorrectArrayAccess()
-        case other =>
-          throw DeltaErrors.dropNestedColumnsFromNonStructTypeException(other)
-      }
-    }
-
+  private def dropColumnInStruct(
+      schema: StructType,
+      position: Seq[Int]): (StructType, StructField) = {
     require(position.nonEmpty, "Don't know where to drop the column")
     val slicePosition = position.head
     if (slicePosition < 0) {
@@ -930,7 +949,7 @@ def normalizeColumnNamesInDataType(
     val (pre, post) = schema.splitAt(slicePosition)
     val field = post.head
     if (position.length > 1) {
-      val (newType, droppedColumn) = dropColumnInChild(field.dataType, position.tail)
+      val (newType, droppedColumn) = dropColumn(field.dataType, position.tail)
       val mid = field.copy(dataType = newType)
 
       StructType(pre ++ Seq(mid) ++ post.tail) -> droppedColumn
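With the recursion hoisted out of the struct-only helpers, `addColumn` and
`dropColumn` accept ordinal paths that start at a map or array root;
`MAP_KEY_INDEX` (0), `MAP_VALUE_INDEX` (1) and `ARRAY_ELEMENT_INDEX` (0) are
the existing constants matched above. A sketch mirroring the new tests:

```scala
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.apache.spark.sql.types._

val schema = MapType(
  keyType = new StructType().add("k", IntegerType).add("k2", StringType),
  valueType = new StructType().add("v", StringType))

// Seq(1, 1): descend into the map value (MAP_VALUE_INDEX), insert at ordinal 1.
val withX = SchemaUtils.addColumn(schema, StructField("x", LongType), Seq(1, 1))
// withX.valueType == StructType(v, x)

// Seq(0, 0): descend into the map key (MAP_KEY_INDEX), drop ordinal 0 and return it.
val (newSchema, dropped) = SchemaUtils.dropColumn(withX, Seq(0, 0))
// newSchema.keyType == StructType(k2); dropped.name == "k"
```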
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala
index 89b9b0cc17..a8fbfd51ff 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala
@@ -1258,6 +1258,35 @@ class SchemaUtilsSuite extends QueryTest
     }
   }
 
+  test("addColumn - top level array") {
+    val a = StructField("a", IntegerType)
+    val b = StructField("b", StringType)
+    val schema = ArrayType(new StructType().add(a).add(b))
+
+    val x = StructField("x", LongType)
+    assert(SchemaUtils.addColumn(schema, x, Seq(0, 1)) ===
+      ArrayType(new StructType().add(a).add(x).add(b)))
+  }
+
+  test("addColumn - top level map") {
+    val k = StructField("k", IntegerType)
+    val v = StructField("v", StringType)
+    val schema = MapType(
+      keyType = new StructType().add(k),
+      valueType = new StructType().add(v))
+
+    val x = StructField("x", LongType)
+    assert(SchemaUtils.addColumn(schema, x, Seq(0, 1)) ===
+      MapType(
+        keyType = new StructType().add(k).add(x),
+        valueType = new StructType().add(v)))
+
+    assert(SchemaUtils.addColumn(schema, x, Seq(1, 1)) ===
+      MapType(
+        keyType = new StructType().add(k),
+        valueType = new StructType().add(v).add(x)))
+  }
+
   ////////////////////////////
   // dropColumn
   ////////////////////////////
@@ -1511,6 +1540,29 @@ class SchemaUtilsSuite extends QueryTest
     }
   }
 
+  test("dropColumn - top level array") {
+    val schema = ArrayType(new StructType().add("a", IntegerType).add("b", StringType))
+
+    assert(SchemaUtils.dropColumn(schema, Seq(0, 0))._1 ===
+      ArrayType(new StructType().add("b", StringType)))
+  }
+
+  test("dropColumn - top level map") {
+    val schema = MapType(
+      keyType = new StructType().add("k", IntegerType).add("k2", StringType),
+      valueType = new StructType().add("v", StringType).add("v2", StringType))
+
+    assert(SchemaUtils.dropColumn(schema, Seq(0, 0))._1 ===
+      MapType(
+        keyType = new StructType().add("k2", StringType),
+        valueType = new StructType().add("v", StringType).add("v2", StringType)))
+
+    assert(SchemaUtils.dropColumn(schema, Seq(1, 0))._1 ===
+      MapType(
+        keyType = new StructType().add("k", IntegerType).add("k2", StringType),
+        valueType = new StructType().add("v2", StringType)))
+  }
+
   /////////////////////////////////
   // normalizeColumnNamesInDataType
   /////////////////////////////////
@@ -2584,6 +2636,45 @@ class SchemaUtilsSuite extends QueryTest
     assert(update === res3)
   }
 
+  test("transform top level array type") {
+    val at = ArrayType(
+      new StructType()
+        .add("s1", IntegerType)
+    )
+
+    var visitedFields = 0
+    val updated = SchemaMergingUtils.transformColumns(at) {
+      case (_, field, _) =>
+        visitedFields += 1
+        field.copy(name = "s1_1", dataType = StringType)
+    }
+
+    assert(visitedFields === 1)
+    assert(updated === ArrayType(new StructType().add("s1_1", StringType)))
+  }
+
+  test("transform top level map type") {
+    val mt = MapType(
+      new StructType()
+        .add("k1", IntegerType),
+      new StructType()
+        .add("v1", IntegerType)
+    )
+
+    var visitedFields = 0
+    val updated = SchemaMergingUtils.transformColumns(mt) {
+      case (_, field, _) =>
+        visitedFields += 1
+        field.copy(name = field.name + "_1", dataType = StringType)
+    }
+
+    assert(visitedFields === 2)
+    assert(updated === MapType(
+      new StructType().add("k1_1", StringType),
+      new StructType().add("v1_1", StringType)
+    ))
+  }
+
   ////////////////////////////
   // pruneEmptyStructs
   ////////////////////////////
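As a closing sanity check (hypothetical, not part of the patch): along the same
ordinal path, the two generalized helpers invert each other, which the new
tests above exercise piecewise:

```scala
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.apache.spark.sql.types._

val schema = ArrayType(new StructType().add("a", IntegerType).add("b", StringType))
val x = StructField("x", LongType)

// Add x inside the element struct, then drop it at the same position:
// the original schema and the original column come back.
val added = SchemaUtils.addColumn(schema, x, Seq(0, 1))
val (restored, dropped) = SchemaUtils.dropColumn(added, Seq(0, 1))
assert(restored == schema && dropped == x)
```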