[SPARK-48667][PYSPARK] Arrow python UDFS didn't support UDT as output type
AngersZhuuuu committed Jun 20, 2024
1 parent 0d9f8a1 commit 4fecae0
Showing 5 changed files with 39 additions and 6 deletions.
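For context, a minimal repro of the user-facing behavior this commit fixes, sketched under the assumption of an active SparkSession bound to `spark`, using the `ExamplePoint`/`ExamplePointUDT` helpers that ship with Spark's test utilities (the same ones the new test below imports):

from pyspark.sql.functions import udf
from pyspark.sql.types import StructType
from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT

schema = StructType().add("point", ExamplePointUDT())
df = spark.createDataFrame([[ExamplePoint(1.0, 2.0)]], schema=schema)

# An identity Arrow-optimized UDF declaring a UDT return type: before this
# commit the executor-side result-schema check rejected it, because the Arrow
# batch carries the UDT's underlying sqlType (array<double> for
# ExamplePointUDT) rather than the UDT itself; with the fix it round-trips.
identity = udf(lambda p: p, returnType=ExamplePointUDT(), useArrow=True)
df.select(identity("point")).show()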
16 changes: 14 additions & 2 deletions python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -21,17 +21,18 @@
 from pyspark.sql import Row
 from pyspark.sql.functions import udf
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
-from pyspark.sql.types import VarcharType
+from pyspark.sql.types import VarcharType, StructType
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
     ReusedSQLTestCase,
+    ExamplePoint,
+    ExamplePointUDT
 )
 from pyspark.util import PythonEvalType


 @unittest.skipIf(
     not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
 )
@@ -219,6 +220,17 @@ def test_udf(a, b):
         with self.assertRaises(PythonException):
             self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()

+    def test_udt_as_return_type(self):
+        data = [
+            ExamplePoint(1.0, 2.0),
+        ]
+        schema = StructType().add("point", ExamplePointUDT())
+        df = self.spark.createDataFrame([data], schema=schema)
+        [row] = df.select(
+            udf(lambda x: x, returnType=ExamplePointUDT(), useArrow=True)("point"),
+        ).collect()
+        self.assertEqual(row[0], ExamplePoint(1.0, 2.0))
+

 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
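For readers unfamiliar with UDTs, here is a rough, hypothetical sketch of what `ExamplePoint`/`ExamplePointUDT` (imported above from `pyspark.testing.sqlutils`) look like; the names `Point`/`PointUDT` are illustrative only. The key detail is `sqlType()`: a UDT value travels through Catalyst and Arrow as this underlying storage type, which is why result columns report `array<double>` rather than the UDT itself:

from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType

class PointUDT(UserDefinedType):
    # Storage type: how a point is represented in Catalyst/Arrow data.
    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point(datum[0], datum[1])

class Point:
    __UDT__ = PointUDT()  # associates the UDT with this Python class

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, Point) and (other.x, other.y) == (self.x, self.y)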
14 changes: 14 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -219,4 +219,18 @@ private[sql] object ArrowUtils {
         valueContainsNull)
     case _ => dt
   }
+
+  def toArrowOutputSchema(dt: DataType): DataType = {
+    dt match {
+      case udt: UserDefinedType[_] => toArrowOutputSchema(udt.sqlType)
+      case arr@ArrayType(elementType, _) =>
+        arr.copy(elementType = toArrowOutputSchema(elementType))
+      case struct@StructType(fields) =>
+        struct.copy(fields.map(field => field.copy(dataType = toArrowOutputSchema(field.dataType))))
+      case map@MapType(keyType, valueType, _) =>
+        map.copy(keyType = toArrowOutputSchema(keyType), valueType = toArrowOutputSchema(valueType))
+      case _ =>
+        dt
+    }
+  }
 }
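In words, the new `toArrowOutputSchema` helper recursively rewrites a data type so that every UDT is replaced by its underlying `sqlType`, descending into arrays, structs, and maps. A line-for-line Python rendering of the same logic, purely for illustration:

from pyspark.sql.types import (
    ArrayType, DataType, MapType, StructField, StructType, UserDefinedType,
)

def to_arrow_output_schema(dt: DataType) -> DataType:
    # Unwrap UDTs to their storage type, recursing through nested containers.
    if isinstance(dt, UserDefinedType):
        return to_arrow_output_schema(dt.sqlType())
    if isinstance(dt, ArrayType):
        return ArrayType(to_arrow_output_schema(dt.elementType), dt.containsNull)
    if isinstance(dt, StructType):
        return StructType([
            StructField(f.name, to_arrow_output_schema(f.dataType), f.nullable, f.metadata)
            for f in dt.fields
        ])
    if isinstance(dt, MapType):
        return MapType(
            to_arrow_output_schema(dt.keyType),
            to_arrow_output_schema(dt.valueType),
            dt.valueContainsNull,
        )
    return dt

With the ExamplePoint helpers, `to_arrow_output_schema(ExamplePointUDT())` yields `ArrayType(DoubleType(), False)`, the type the Arrow reader actually reports for such a column.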
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.GroupStateTimeout
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils

 /**
  * Physical version of `ObjectProducer`.
@@ -252,7 +253,8 @@ case class MapPartitionsInRWithArrowExec(
     val outputProject = UnsafeProjection.create(output, output)
     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from dapply(): " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator.asScala
     }.map(outputProject)
@@ -598,7 +600,8 @@ case class FlatMapGroupsInRWithArrowExec(

     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from gapply(): " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from gapply(): " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator().asScala
     }.map(outputProject)
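The executor-side checks in this file (dapply and gapply) and in the next two files (arrow-optimized pandas_udf and Python UDTF) all change the same way: the actual column types read back from Arrow never contain UDTs, so the expected types must be unwrapped before comparison. A small illustration, reusing the `to_arrow_output_schema` sketch above:

from pyspark.sql.types import ArrayType, DoubleType
from pyspark.testing.sqlutils import ExamplePointUDT

expected = [ExamplePointUDT()]             # the UDF's declared output type
actual = [ArrayType(DoubleType(), False)]  # what the Arrow column reports

assert expected != actual                                        # old check: spurious failure
assert [to_arrow_output_schema(t) for t in expected] == actual   # new check: passes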
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils

 /**
  * Grouped a iterator into batches.
@@ -125,7 +126,8 @@ class ArrowEvalPythonEvaluatorFactory(

     columnarBatchIter.flatMap { batch =>
       val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
       batch.rowIterator.asScala
     }
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

 /**
@@ -81,7 +82,8 @@ case class ArrowEvalPythonUDTFExec(

       val actualDataTypes = (0 until flattenedBatch.numCols()).map(
         i => flattenedBatch.column(i).dataType())
-      assert(outputTypes == actualDataTypes, "Invalid schema from arrow-enabled Python UDTF: " +
+      val arrowOutputTypes = outputTypes.map(ArrowUtils.toArrowOutputSchema)
+      assert(arrowOutputTypes == actualDataTypes, "Invalid schema from arrow-enabled Python UDTF: " +
         s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")

       flattenedBatch.setNumRows(batch.numRows())
