Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
AngersZhuuuu committed Jun 21, 2024
1 parent 83cd788 commit 110ff32
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
8 changes: 3 additions & 5 deletions python/pyspark/sql/tests/test_arrow_python_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
pyarrow_requirement_message,
ReusedSQLTestCase,
ExamplePoint,
ExamplePointUDT
ExamplePointUDT,
)
from pyspark.util import PythonEvalType

Expand Down Expand Up @@ -221,10 +221,8 @@ def test_udf(a, b):
with self.assertRaises(PythonException):
self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()

def test_udt_as_return_type(self):
data = [
ExamplePoint(1.0, 2.0),
]
def test_udt_as_udf_return_type(self):
data = [ExamplePoint(1.0, 2.0)]
schema = StructType().add("point", ExamplePointUDT())
df = self.spark.createDataFrame([data], schema=schema)
[row] = df.select(
Expand Down
8 changes: 3 additions & 5 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
pyarrow_requirement_message,
ReusedSQLTestCase,
ExamplePoint,
ExamplePointUDT
ExamplePointUDT,
)


Expand Down Expand Up @@ -2814,14 +2814,12 @@ def eval(self):
with self.assertRaisesRegex(PythonException, "UDTF_ARROW_TYPE_CAST_ERROR"):
udtf(TestUDTF, returnType=ret_type)().collect()

def test_udtf_as_return_type(self):
def test_udt_as_udtf_return_type(self):
class TestUDTF:
def eval(self):
yield ExamplePoint(0, 1),

data = [
ExamplePoint(1.0, 2.0),
]
data = [ExamplePoint(1.0, 2.0),]
schema = StructType().add("point", ExamplePointUDT())
df = self.spark.createDataFrame([data], schema=schema)
[row] = df.select(
Expand Down

0 comments on commit 110ff32

Please sign in to comment.