diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 5a66d61cb66a2..64055f73f6e73 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -21,17 +21,18 @@
 from pyspark.sql import Row
 from pyspark.sql.functions import udf
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
-from pyspark.sql.types import VarcharType
+from pyspark.sql.types import VarcharType, StructType
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
     ReusedSQLTestCase,
+    ExamplePoint,
+    ExamplePointUDT,
 )
 from pyspark.util import PythonEvalType
 
-
 @unittest.skipIf(
     not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
 )
@@ -219,6 +220,17 @@ def test_udf(a, b):
         with self.assertRaises(PythonException):
             self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
 
+    def test_udt_as_return_type(self):
+        data = [
+            ExamplePoint(1.0, 2.0),
+        ]
+        schema = StructType().add("point", ExamplePointUDT())
+        df = self.spark.createDataFrame([data], schema=schema)
+        [row] = df.select(
+            udf(lambda x: x, returnType=ExamplePointUDT(), useArrow=True)("point"),
+        ).collect()
+        self.assertEqual(row[0], ExamplePoint(1.0, 2.0))
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6852fe09ef96b..354cbecd1bd7a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -219,4 +219,25 @@ private[sql] object ArrowUtils {
       valueContainsNull)
     case _ => dt
   }
+
+  /**
+   * Converts a data type to the type actually carried by Arrow output,
+   * recursively replacing every [[UserDefinedType]] with its underlying SQL type.
+   */
+  def toArrowOutputSchema(dt: DataType): DataType = {
+    dt match {
+      case udt: UserDefinedType[_] => toArrowOutputSchema(udt.sqlType)
+      case arr @ ArrayType(elementType, _) =>
+        arr.copy(elementType = toArrowOutputSchema(elementType))
+      case struct @ StructType(fields) =>
+        struct.copy(fields.map(field =>
+          field.copy(dataType = toArrowOutputSchema(field.dataType))))
+      case map @ MapType(keyType, valueType, _) =>
+        map.copy(
+          keyType = toArrowOutputSchema(keyType),
+          valueType = toArrowOutputSchema(valueType))
+      case _ =>
+        dt
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 2fefd8f70cd5c..4eb7f21f1479f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.GroupStateTimeout
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * Physical version of `ObjectProducer`.
@@ -252,7 +253,8 @@ case class MapPartitionsInRWithArrowExec(
     val outputProject = UnsafeProjection.create(output, output)
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from dapply(): " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator.asScala
     }.map(outputProject)
@@ -598,7 +600,8 @@ case class FlatMapGroupsInRWithArrowExec(
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from gapply(): " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from gapply(): " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator().asScala
     }.map(outputProject)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index da4c5bff34459..1001ec7b36013 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
 
 /**
  * Grouped a iterator into batches.
@@ -125,7 +126,8 @@ class ArrowEvalPythonEvaluatorFactory(
 
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator.asScala
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9e210bf5241bb..91f3e05bab9fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -81,7 +82,8 @@ case class ArrowEvalPythonUDTFExec(
 
       val actualDataTypes = (0 until flattenedBatch.numCols()).map(
         i => flattenedBatch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from arrow-enabled Python UDTF: " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from arrow-enabled Python UDTF: " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       flattenedBatch.setNumRows(batch.numRows())
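
For reference, a minimal usage sketch of the behavior this patch enables, mirroring
the new test_udt_as_return_type test above. This is an illustration only: it assumes
an active SparkSession bound to `spark`, and uses the ExamplePoint/ExamplePointUDT
test helpers that the patch imports from pyspark.testing.sqlutils.

    from pyspark.sql.functions import udf
    from pyspark.sql.types import StructType
    from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT

    # One-row, one-column DataFrame whose column carries a user-defined type.
    schema = StructType().add("point", ExamplePointUDT())
    df = spark.createDataFrame([[ExamplePoint(1.0, 2.0)]], schema=schema)

    # With useArrow=True, the UDF's declared UDT return type is now reduced to its
    # underlying sqlType (a struct) via ArrowUtils.toArrowOutputSchema before the
    # Arrow batch schema check, so the identity UDF below no longer fails the
    # "Invalid schema from pandas_udf" assertion.
    identity = udf(lambda p: p, returnType=ExamplePointUDT(), useArrow=True)
    assert df.select(identity("point")).collect()[0][0] == ExamplePoint(1.0, 2.0)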