diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 5a19f7fd7377c..f21e66c21dba8 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -2817,12 +2817,10 @@ def eval(self): def test_udt_as_udtf_return_type(self): class TestUDTF: def eval(self): - yield ExamplePoint(0, 1), + yield Row(point=ExamplePoint(0, 1)), schema = StructType().add("point", ExamplePointUDT()) - func = udtf(TestUDTF, returnType=schema, useArrow=True) - [row] = func().collect() - self.assertEqual(row[0], ExamplePoint(1.0, 2.0)) + self._check_result_or_exception(TestUDTF, schema, [Row(point=ExamplePoint(1.0, 2.0))]) class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):