Skip to content

Commit

Permalink
[SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar fu…
Browse files Browse the repository at this point in the history
…nctions

### What changes were proposed in this pull request?

Pass `isDeterministic` flag to `ApplyFunctionExpression`, `Invoke` and `StaticInvoke` when processing V2 scalar functions.

### Why are the changes needed?

A V2 scalar function can be declared as non-deterministic. However, currently Spark doesn't pass the flag when converting the V2 function to a catalyst expression, which could lead to incorrect results if being applied with certain optimizations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test.

Closes apache#35243 from sunchao/SPARK-37957.

Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
  • Loading branch information
sunchao committed Jan 19, 2022
1 parent 5cf8108 commit 3860ac5
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2299,12 +2299,14 @@ class Analyzer(override val catalogManager: CatalogManager)
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
propagateNull = false, returnNullable = scalarFunc.isResultNullable)
propagateNull = false, returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
returnNullable = scalarFunc.isResultNullable)
returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ case class ApplyFunctionExpression(
override def name: String = function.name()
override def dataType: DataType = function.resultType()
override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
override lazy val deterministic: Boolean = function.isDeterministic &&
children.forall(_.deterministic)

private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ object SerializerSupport {
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
* @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
* will not apply certain optimizations such as constant folding.
*/
case class StaticInvoke(
staticObject: Class[_],
Expand All @@ -248,7 +250,8 @@ case class StaticInvoke(
arguments: Seq[Expression] = Nil,
inputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
returnNullable: Boolean = true) extends InvokeLike {
returnNullable: Boolean = true,
isDeterministic: Boolean = true) extends InvokeLike {

val objectName = staticObject.getName.stripSuffix("$")
val cls = if (staticObject.getName == objectName) {
Expand All @@ -259,6 +262,7 @@ case class StaticInvoke(

override def nullable: Boolean = needNullCheck || returnNullable
override def children: Seq[Expression] = arguments
override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
@transient lazy val method = findMethod(cls, functionName, argClasses)
Expand Down Expand Up @@ -340,6 +344,8 @@ case class StaticInvoke(
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
* @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
* will not apply certain optimizations such as constant folding.
*/
case class Invoke(
targetObject: Expression,
Expand All @@ -348,12 +354,14 @@ case class Invoke(
arguments: Seq[Expression] = Nil,
methodInputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
returnNullable : Boolean = true) extends InvokeLike {
returnNullable : Boolean = true,
isDeterministic: Boolean = true) extends InvokeLike {

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)

override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
override def inputTypes: Seq[AbstractDataType] =
if (methodInputTypes.nonEmpty) {
Seq(targetObject.dataType) ++ methodInputTypes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public String description() {
return "long_add";
}

private abstract static class JavaLongAddBase implements ScalarFunction<Long> {
public abstract static class JavaLongAddBase implements ScalarFunction<Long> {
private final boolean isResultNullable;

JavaLongAddBase(boolean isResultNullable) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql.connector.catalog.functions;

import java.util.Random;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.StructType;

/**
* Test V2 function which add a random number to the input integer.
*/
public class JavaRandomAdd implements UnboundFunction {
private final BoundFunction fn;

public JavaRandomAdd(BoundFunction fn) {
this.fn = fn;
}

@Override
public String name() {
return "rand";
}

@Override
public BoundFunction bind(StructType inputType) {
if (inputType.fields().length != 1) {
throw new UnsupportedOperationException("Expect exactly one argument");
}
if (inputType.fields()[0].dataType() instanceof IntegerType) {
return fn;
}
throw new UnsupportedOperationException("Expect IntegerType");
}

@Override
public String description() {
return "rand_add: add a random integer to the input\n" +
"rand_add(int) -> int";
}

public abstract static class JavaRandomAddBase implements ScalarFunction<Integer> {
@Override
public DataType[] inputTypes() {
return new DataType[] { DataTypes.IntegerType };
}

@Override
public DataType resultType() {
return DataTypes.IntegerType;
}

@Override
public String name() {
return "rand_add";
}

@Override
public boolean isDeterministic() {
return false;
}
}

public static class JavaRandomAddDefault extends JavaRandomAddBase {
private final Random rand = new Random();

@Override
public Integer produceResult(InternalRow input) {
return input.getInt(0) + rand.nextInt();
}
}

public static class JavaRandomAddMagic extends JavaRandomAddBase {
private final Random rand = new Random();

public int invoke(int input) {
return input + rand.nextInt();
}
}

public static class JavaRandomAddStaticMagic extends JavaRandomAddBase {
private static final Random rand = new Random();

public static int invoke(int input) {
return input + rand.nextInt();
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public BoundFunction bind(StructType inputType) {
return fn;
}

throw new UnsupportedOperationException("Except StringType");
throw new UnsupportedOperationException("Expect StringType");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ package org.apache.spark.sql.connector
import java.util
import java.util.Collections

import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
import test.org.apache.spark.sql.connector.catalog.functions._
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._

import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -365,6 +367,31 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
}
}

test("SPARK-37957: pass deterministic flag when creating V2 function expression") {
def checkDeterministic(df: DataFrame): Unit = {
val result = df.queryExecution.executedPlan.find(_.isInstanceOf[ProjectExec])
assert(result.isDefined, s"Expect to find ProjectExec")
assert(!result.get.asInstanceOf[ProjectExec].projectList.exists(_.deterministic),
"Expect expressions in projectList to be non-deterministic")
}

catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
Seq(new JavaRandomAddDefault, new JavaRandomAddMagic,
new JavaRandomAddStaticMagic).foreach { fn =>
addFunction(Identifier.of(Array("ns"), "rand_add"), new JavaRandomAdd(fn))
checkDeterministic(sql("SELECT testcat.ns.rand_add(42)"))
}

// A function call is non-deterministic if one of its arguments is non-deterministic
Seq(new JavaLongAddDefault(true), new JavaLongAddMagic(true),
new JavaLongAddStaticMagic(true)).foreach { fn =>
addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(fn))
addFunction(Identifier.of(Array("ns"), "rand_add"),
new JavaRandomAdd(new JavaRandomAddDefault))
checkDeterministic(sql("SELECT testcat.ns.add(10, testcat.ns.rand_add(42))"))
}
}

private case class StrLen(impl: BoundFunction) extends UnboundFunction {
override def description(): String =
"""strlen: returns the length of the input string
Expand Down

0 comments on commit 3860ac5

Please sign in to comment.