From 57e0c8a3ef5b4d35729d26b59af0651a674e92e9 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Mon, 24 Jun 2024 14:38:04 +0800 Subject: [PATCH] Update test_udtf.py --- python/pyspark/sql/tests/test_udtf.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 5a19f7fd7377c..f21e66c21dba8 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -2817,12 +2817,10 @@ def eval(self): def test_udt_as_udtf_return_type(self): class TestUDTF: def eval(self): - yield ExamplePoint(0, 1), + yield Row(point=ExamplePoint(0, 1)), schema = StructType().add("point", ExamplePointUDT()) - func = udtf(TestUDTF, returnType=schema, useArrow=True) - [row] = func().collect() - self.assertEqual(row[0], ExamplePoint(1.0, 2.0)) + self._check_result_or_exception(TestUDTF, schema, [Row(point=ExamplePoint(1.0, 2.0))]) class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):