
Commit b957d7f

[SPARK-44856][PYTHON] Improve Python UDTF arrow serializer performance
### What changes were proposed in this pull request?

This PR removes the pandas <> Arrow <> pandas conversion in Arrow-optimized Python UDTFs by using PyArrow directly.

### Why are the changes needed?

Currently, there is a lot of overhead in the Arrow serializer for Python UDTFs. The overhead largely comes from converting Arrow batches into pandas Series and converting the UDTF's results back into a pandas DataFrame. Converting Python objects to Arrow directly (and vice versa) avoids the expensive pandas round trip.

### Does this PR introduce _any_ user-facing change?

Yes. Previously the conversion went through pandas instances; with that round trip removed, type coercion can differ when the produced output has a schema different from the specified return schema (see the migration guide entry added in this PR). The previous behavior can be restored by enabling `spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled`.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50099 from HyukjinKwon/SPARK-44856.

Lead-authored-by: Hyukjin Kwon <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
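To make the intent concrete, here is a minimal, hypothetical sketch (made-up data and names, not the serializer code touched by this PR) contrasting the pandas round trip with building an Arrow record batch directly:

```python
# Illustrative sketch only, not the serializer code from this PR: it contrasts
# the pandas round trip with building an Arrow record batch directly.
import pandas as pd
import pyarrow as pa

rows = [(1, "a"), (2, "b")]
schema = pa.schema([("id", pa.int64()), ("v", pa.string())])

# Before (conceptually): Python rows -> pandas DataFrame -> Arrow record batch.
pdf = pd.DataFrame(rows, columns=["id", "v"])
batch_via_pandas = pa.RecordBatch.from_pandas(pdf, schema=schema)

# After (conceptually): Python rows -> Arrow record batch, no pandas in between.
columns = list(zip(*rows))
batch_direct = pa.RecordBatch.from_arrays(
    [pa.array(col, type=field.type) for col, field in zip(columns, schema)],
    schema=schema,
)

assert batch_via_pandas.equals(batch_direct)
```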
1 parent d5f735b commit b957d7f

9 files changed: +528 -24 lines changed

python/docs/source/migration_guide/pyspark_upgrade.rst

Lines changed: 3 additions & 0 deletions

@@ -24,6 +24,9 @@ Upgrading from PySpark 4.0 to 4.1
 
 * In Spark 4.1, Arrow-optimized Python UDF supports UDT input / output instead of falling back to the regular UDF. To restore the legacy behavior, set ``spark.sql.execution.pythonUDF.arrow.legacy.fallbackOnUDT`` to ``true``.
 
+* In Spark 4.1, unnecessary conversion to pandas instances is removed when ``spark.sql.execution.pythonUDTF.arrow.enabled`` is enabled. As a result, the type coercion changes when the produced output has a schema different from the specified schema. To restore the previous behavior, enable ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled``.
+
+
 Upgrading from PySpark 3.5 to 4.0
 ---------------------------------
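For reference, a usage sketch of the configurations named in the new migration note:

```python
# Hypothetical usage of the configurations named in the migration note above.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Keep the Arrow-optimized Python UDTF path enabled ...
spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
# ... but restore the pre-4.1 pandas-based conversion and its type coercion behavior.
spark.conf.set("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true")
```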

python/pyspark/errors/error-conditions.json

Lines changed: 5 additions & 0 deletions

@@ -998,6 +998,11 @@
       "Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
     ]
   },
+  "UDTF_ARROW_TYPE_CONVERSION_ERROR": {
+    "message": [
+      "Cannot convert the output value of the input '<data>' with type '<schema>' to the specified return type of the column: '<arrow_schema>'. Please check if the data types match and try again."
+    ]
+  },
   "UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD": {
     "message": [
       "Failed to evaluate the user-defined table function '<name>' because its constructor is invalid: the function implements the 'analyze' method, but its constructor has more than two arguments (including the 'self' reference). Please update the table function so that its constructor accepts exactly one 'self' argument, or one 'self' argument plus another argument for the result of the 'analyze' method, and try the query again."

python/pyspark/sql/conversion.py

Lines changed: 3 additions & 0 deletions

@@ -345,6 +345,9 @@ def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool)
             if isinstance(item, dict):
                 for i, col in enumerate(column_names):
                     pylist[i].append(column_convs[i](item.get(col)))
+            elif item is None:
+                for i, col in enumerate(column_names):
+                    pylist[i].append(None)
             else:
                 if len(item) != len(column_names):
                     raise PySparkValueError(
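The added branch fits into a column-pivoting loop; the following standalone sketch mirrors the idea (it is not the actual `pyspark.sql.conversion` internals):

```python
# Standalone sketch of the column-pivoting pattern the change above extends;
# it mirrors the idea, not the actual pyspark.sql.conversion internals.
import pyarrow as pa

column_names = ["id", "v"]
data = [{"id": 1, "v": "a"}, None, (3, "c")]  # a dict row, a None row, a tuple row

pylist = [[] for _ in column_names]
for item in data:
    if isinstance(item, dict):
        for i, col in enumerate(column_names):
            pylist[i].append(item.get(col))
    elif item is None:
        # The newly added case: a None row becomes a null in every column.
        for i, _ in enumerate(column_names):
            pylist[i].append(None)
    else:
        for i, value in enumerate(item):
            pylist[i].append(value)

batch = pa.RecordBatch.from_arrays(
    [pa.array(col) for col in pylist], names=column_names
)
# batch now has 3 rows: (1, "a"), (null, null), (3, "c")
```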

python/pyspark/sql/pandas/serializers.py

Lines changed: 17 additions & 3 deletions

@@ -167,9 +167,14 @@ def wrap_and_init_stream():
                 assert isinstance(batch, pa.RecordBatch)
 
                 # Wrap the root struct
-                struct = pa.StructArray.from_arrays(
-                    batch.columns, fields=pa.struct(list(batch.schema))
-                )
+                if len(batch.columns) == 0:
+                    # When batch has no column, it should still create
+                    # an empty batch with the number of rows set.
+                    struct = pa.array([{}] * batch.num_rows)
+                else:
+                    struct = pa.StructArray.from_arrays(
+                        batch.columns, fields=pa.struct(list(batch.schema))
+                    )
                 batch = pa.RecordBatch.from_arrays([struct], ["_0"])
 
                 # Write the first record batch with initialization.
@@ -181,6 +186,15 @@ def wrap_and_init_stream():
         return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)
 
 
+class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
+    """
+    Same as :class:`ArrowStreamUDFSerializer` but it does not flatten when loading batches.
+    """
+
+    def load_stream(self, stream):
+        return ArrowStreamSerializer.load_stream(self, stream)
+
+
 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
     """
     Serializes pyarrow.RecordBatch data with Arrow streaming format.
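The changed wrapping logic can be exercised outside Spark; a small sketch with an assumed helper name `wrap_as_struct`:

```python
# Standalone sketch of the wrapping logic above (the helper name `wrap_as_struct`
# is assumed); it folds all columns of a record batch into one struct column
# "_0", keeping the row count even when the batch has no columns.
import pyarrow as pa

def wrap_as_struct(batch: pa.RecordBatch) -> pa.RecordBatch:
    if len(batch.columns) == 0:
        # An array of empty structs preserves num_rows for a zero-column batch.
        struct = pa.array([{}] * batch.num_rows)
    else:
        struct = pa.StructArray.from_arrays(
            batch.columns, fields=pa.struct(list(batch.schema))
        )
    return pa.RecordBatch.from_arrays([struct], ["_0"])

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2]), pa.array(["a", "b"])], names=["x", "y"]
)
wrapped = wrap_as_struct(batch)  # one struct<x: int64, y: string> column, 2 rows
```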

python/pyspark/sql/tests/connect/test_parity_udtf.py

Lines changed: 39 additions & 1 deletion

@@ -17,7 +17,11 @@
 import unittest
 
 from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
+from pyspark.sql.tests.test_udtf import (
+    BaseUDTFTestsMixin,
+    UDTFArrowTestsMixin,
+    LegacyUDTFArrowTestsMixin,
+)
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
@@ -88,16 +92,50 @@ def _add_file(self, path):
         self.spark.addArtifacts(path, file=True)
 
 
+class LegacyArrowUDTFParityTests(LegacyUDTFArrowTestsMixin, UDTFParityTests):
+    @classmethod
+    def setUpClass(cls):
+        super(LegacyArrowUDTFParityTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
+        cls.spark.conf.set(
+            "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "true"
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
+            cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled")
+        finally:
+            super(LegacyArrowUDTFParityTests, cls).tearDownClass()
+
+    def test_udtf_access_spark_session_connect(self):
+        df = self.spark.range(10)
+
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                df.collect()
+                yield 1,
+
+        with self.assertRaisesRegex(PythonException, "NO_ACTIVE_SESSION"):
+            TestUDTF().collect()
+
+
 class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
     @classmethod
     def setUpClass(cls):
         super(ArrowUDTFParityTests, cls).setUpClass()
         cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
+        cls.spark.conf.set(
+            "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false"
+        )
 
     @classmethod
     def tearDownClass(cls):
         try:
             cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
+            cls.spark.conf.unset("spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled")
         finally:
             super(ArrowUDTFParityTests, cls).tearDownClass()
