
Commit a5c88ec

HyukjinKwon authored and cloud-fan committed
[SPARK-28321][SQL] 0-args Java UDF should not be called only once
## What changes were proposed in this pull request?

A 0-arg Java UDF is the only case where the function is invoked eagerly, before it is turned into an expression. As a result, the call happens on the driver side and the function returns the same value for every row. This looks like a mistake; the patch defers the invocation by wrapping it in a closure.

## How was this patch tested?

A unit test was added.

Closes apache#25108 from HyukjinKwon/SPARK-28321.

Authored-by: HyukjinKwon <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
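For context, a minimal Scala sketch of the difference (illustrative only, not part of the patch; the UDF body and value names are made up):

import org.apache.spark.sql.api.java.UDF0

val f: UDF0[Int] = new UDF0[Int] {
  override def call(): Int = scala.util.Random.nextInt()
}

// Before the fix: the body runs exactly once, right here on the driver,
// and the result is frozen into a constant closure.
val eagerResult: Any = f.asInstanceOf[UDF0[Any]].call()
val frozen: () => Any = () => eagerResult        // same value for every row

// After the fix: the call is deferred until ScalaUDF invokes the closure,
// i.e. once per row during query execution.
val deferred: () => Any = () => f.asInstanceOf[UDF0[Any]].call()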
1 parent 507b745 commit a5c88ec

File tree: 3 files changed, +19 −10


sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

+5 −5
@@ -142,16 +142,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
     val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
     val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
     val version = if (i == 0) "2.3.0" else "1.3.0"
-    val funcCall = if (i == 0) "() => func" else "func"
+    val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
     println(s"""
       |/**
       | * Register a deterministic Java UDF$i instance as user-defined function (UDF).
       | * @since $version
       | */
       |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
-      |  val func = f$anyCast.call($anyParams)
+      |  val func = $funcCall
       |  def builder(e: Seq[Expression]) = if (e.length == $i) {
-      |    ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
+      |    ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
       |  } else {
       |    throw new AnalysisException("Invalid number of arguments for function " + name +
       |      ". Expected: $i; Found: " + e.length)
@@ -717,9 +717,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 2.3.0
    */
   def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
-    val func = f.asInstanceOf[UDF0[Any]].call()
+    val func = () => f.asInstanceOf[UDF0[Any]].call()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
-      ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 0; Found: " + e.length)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

+5 −5
@@ -3932,7 +3932,7 @@ object functions {
     val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
     val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
     val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
-    val funcCall = if (i == 0) "() => func" else "func"
+    val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
     println(s"""
       |/**
       | * Defines a Java UDF$i instance as user-defined function (UDF).
@@ -3944,8 +3944,8 @@ object functions {
       | * @since 2.3.0
       | */
       |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
-      |  val func = f$anyCast.call($anyParams)
-      |  SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
+      |  val func = $funcCall
+      |  SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None))
       |}""".stripMargin)
   }
 
@@ -4145,8 +4145,8 @@ object functions {
    * @since 2.3.0
    */
   def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
-    val func = f.asInstanceOf[UDF0[Any]].call()
-    SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
+    val func = () => f.asInstanceOf[UDF0[Any]].call()
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None))
   }
 
   /**
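
Similarly, the functions.udf overload for the DataFrame API now defers the call. A usage sketch mirroring the new test below (assumes a SparkSession named spark; the column name is made up):

import org.apache.spark.sql.api.java.UDF0
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

val javaRandom = udf(new UDF0[Int] {
  override def call(): Int = scala.util.Random.nextInt()
}, IntegerType).asNondeterministic()

// Each row gets its own value because the UDF0 body is evaluated per row,
// not once when udf(...) is constructed.
spark.range(4).select(javaRandom().as("r")).show()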

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

+9
@@ -514,4 +514,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       assert(df.collect().toSeq === Seq(Row(expected)))
     }
   }
+
+  test("SPARK-28321 0-args Java UDF should not be called only once") {
+    val nonDeterministicJavaUDF = udf(
+      new UDF0[Int] {
+        override def call(): Int = scala.util.Random.nextInt()
+      }, IntegerType).asNondeterministic()
+
+    assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2)
+  }
 }
