[FLINK-36706] Refactor TypeInferenceExtractor for PTFs
twalthr committed Dec 16, 2024
1 parent 9677350 commit 7bf93f2
Showing 24 changed files with 2,056 additions and 1,090 deletions.
@@ -21,12 +21,13 @@ import org.apache.flink.table.annotation.{DataTypeHint, FunctionHint}
 import org.apache.flink.table.api.DataTypes
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.types.extraction.TypeInferenceExtractorTest.TestSpec
-import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, TypeStrategies}
+import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, StaticArgument, TypeStrategies}
 
 import org.assertj.core.api.AssertionsForClassTypes.assertThat
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.MethodSource
 
+import java.util
 import java.util.{stream, Optional}
 
 import scala.annotation.varargs
@@ -36,28 +37,24 @@ class TypeInferenceExtractorScalaTest {
 
   @ParameterizedTest
   @MethodSource(Array("testData"))
-  def testArgumentNames(testSpec: TestSpec): Unit = {
-    if (testSpec.expectedArgumentNames != null) {
-      assertThat(testSpec.typeInferenceExtraction.get.getNamedArguments)
-        .isEqualTo(Optional.of(testSpec.expectedArgumentNames))
-    }
-  }
-
-  @ParameterizedTest
-  @MethodSource(Array("testData"))
-  def testArgumentTypes(testSpec: TestSpec): Unit = {
-    if (testSpec.expectedArgumentTypes != null) {
-      assertThat(testSpec.typeInferenceExtraction.get.getTypedArguments)
-        .isEqualTo(Optional.of(testSpec.expectedArgumentTypes))
+  def testStaticArguments(testSpec: TestSpec): Unit = {
+    if (testSpec.expectedStaticArguments != null) {
+      val staticArguments = testSpec.typeInferenceExtraction.get.getStaticArguments
+      assertThat(staticArguments).isEqualTo(Optional.of(testSpec.expectedStaticArguments))
     }
   }
 
   @ParameterizedTest
   @MethodSource(Array("testData"))
   def testOutputTypeStrategy(testSpec: TestSpec): Unit = {
-    if (!testSpec.expectedOutputStrategies.isEmpty) {
+    if (testSpec.expectedOutputStrategies.size == 1) {
       assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
-        .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies))
+        .isEqualTo(testSpec.expectedOutputStrategies.values.iterator.next)
+    } else {
+      assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy)
+        .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies))
+    }
   }
 }
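The reworked assertion above reflects a behavioral change in the extractor: a function with exactly one signature now yields its sole output TypeStrategy directly, while multiple signatures are still combined via TypeStrategies.mapping(...). A minimal sketch of the single-signature shape (the class name is hypothetical; the API calls are Flink's):

```java
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.inference.TypeStrategies;
import org.apache.flink.table.types.inference.TypeStrategy;

public class SingleStrategySketch {
    public static void main(String[] args) {
        // With one signature, the extractor is now expected to return this strategy
        // as-is, without wrapping it in a one-entry mapping.
        TypeStrategy single = TypeStrategies.explicit(DataTypes.BOOLEAN().notNull());
        System.out.println(single);
    }
}
```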
@@ -68,22 +65,12 @@ object TypeInferenceExtractorScalaTest {
   // Scala function with data type hint
   TestSpec
     .forScalarFunction(classOf[ScalaScalarFunction])
-    .expectNamedArguments("i", "s", "d")
-    .expectTypedArguments(
-      DataTypes.INT.notNull().bridgedTo(classOf[Int]),
-      DataTypes.STRING,
-      DataTypes.DECIMAL(10, 4))
-    .expectOutputMapping(
-      InputTypeStrategies.sequence(
-        Array[String]("i", "s", "d"),
-        Array[ArgumentTypeStrategy](
-          InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
-          InputTypeStrategies.explicit(DataTypes.STRING),
-          InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4))
-        )
-      ),
-      TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))
-    ),
+    .expectStaticArgument(
+      StaticArgument.scalar("i", DataTypes.INT.notNull().bridgedTo(classOf[Int]), false))
+    .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING, false))
+    .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DECIMAL(10, 4), false))
+    .expectOutput(TypeStrategies.explicit(
+      DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),
   TestSpec
     .forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction])
     .expectOutputMapping(
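The TestSpec refactoring above collapses the separate named-argument and typed-argument expectations into a single StaticArgument-based expectation, matching what the extractor now derives from an eval signature. A minimal Java sketch of the same three arguments (the class name is hypothetical; StaticArgument.scalar and its trailing optional flag are taken from the diff above):

```java
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.inference.StaticArgument;

import java.util.List;

public class StaticArgumentSketch {
    public static void main(String[] args) {
        // The three scalar arguments of ScalaScalarFunction as StaticArguments;
        // the final boolean marks an argument as optional.
        List<StaticArgument> expected =
                List.of(
                        StaticArgument.scalar(
                                "i", DataTypes.INT().notNull().bridgedTo(int.class), false),
                        StaticArgument.scalar("s", DataTypes.STRING(), false),
                        StaticArgument.scalar("d", DataTypes.DECIMAL(10, 4), false));
        expected.forEach(System.out::println);
    }
}
```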
@@ -40,7 +40,7 @@ public enum ArgumentTrait {
      *
      * <p>It's the default if no {@link ArgumentHint} is provided.
      */
-    SCALAR(StaticArgumentTrait.SCALAR),
+    SCALAR(true, StaticArgumentTrait.SCALAR),
 
     /**
      * An argument that accepts a table "as row" (i.e. with row semantics). This trait only applies
@@ -56,7 +56,7 @@
      * can be processed independently. The framework is free in how to distribute rows across
      * virtual processors and each virtual processor has access only to the currently processed row.
      */
-    TABLE_AS_ROW(StaticArgumentTrait.TABLE_AS_ROW),
+    TABLE_AS_ROW(true, StaticArgumentTrait.TABLE_AS_ROW),
 
     /**
      * An argument that accepts a table "as set" (i.e. with set semantics). This trait only applies
@@ -77,22 +77,28 @@
      * <p>It is also possible not to provide a key ({@link #OPTIONAL_PARTITION_BY}), in which case
      * only one virtual processor handles the entire table, thereby losing scalability benefits.
      */
-    TABLE_AS_SET(StaticArgumentTrait.TABLE_AS_SET),
+    TABLE_AS_SET(true, StaticArgumentTrait.TABLE_AS_SET),
 
     /**
      * Defines that a PARTITION BY clause is optional for {@link #TABLE_AS_SET}. By default, it is
      * mandatory for improving the parallel execution by distributing the table by key.
      */
-    OPTIONAL_PARTITION_BY(StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET);
+    OPTIONAL_PARTITION_BY(false, StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET);
 
+    private final boolean isRoot;
     private final StaticArgumentTrait staticTrait;
     private final Set<ArgumentTrait> requirements;
 
-    ArgumentTrait(StaticArgumentTrait staticTrait, ArgumentTrait... requirements) {
+    ArgumentTrait(boolean isRoot, StaticArgumentTrait staticTrait, ArgumentTrait... requirements) {
+        this.isRoot = isRoot;
         this.staticTrait = staticTrait;
         this.requirements = Arrays.stream(requirements).collect(Collectors.toSet());
     }
 
+    public boolean isRoot() {
+        return isRoot;
+    }
+
     public Set<ArgumentTrait> getRequirements() {
         return requirements;
     }
@@ -24,6 +24,7 @@
 import org.apache.flink.table.annotation.DataTypeHint;
 import org.apache.flink.table.annotation.FunctionHint;
 import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.types.extraction.TypeInferenceExtractor;
 import org.apache.flink.table.types.inference.TypeInference;
 import org.apache.flink.util.Collector;
 
@@ -225,8 +226,9 @@ public final FunctionKind getKind() {
     }
 
     @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
     public TypeInference getTypeInference(DataTypeFactory typeFactory) {
-        throw new UnsupportedOperationException("Type inference is not implemented yet.");
+        return TypeInferenceExtractor.forProcessTableFunction(typeFactory, (Class) getClass());
     }
 
     /**
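With this change, ProcessTableFunction.getTypeInference no longer throws but delegates to the reflection-based extractor, as other UDF kinds already do. A minimal sketch of a PTF whose signature the extractor can now inspect — the class is a hypothetical example, with ArgumentHint/collect usage following the style of Flink's PTF documentation:

```java
import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.ArgumentTrait;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.types.Row;

// getTypeInference(typeFactory) would reflectively derive one table argument with
// row semantics, one scalar STRING argument, and a STRING output from this signature.
public class GreetingFunction extends ProcessTableFunction<String> {
    public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row input, String greeting) {
        collect(greeting + ", " + input.getField(0));
    }
}
```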
@@ -92,6 +92,8 @@ public final class UserDefinedFunctionHelper {
 
     public static final String ASYNC_TABLE_EVAL = "eval";
 
+    public static final String PROCESS_TABLE_EVAL = "eval";
+
     /**
      * Tries to infer the TypeInformation of an AggregateFunction's accumulator type.
      *
@@ -320,9 +322,12 @@ public static void validateClassForRuntime(
                 methods.stream()
                         .anyMatch(
                                 method ->
-                                        ExtractionUtils.isInvokable(method, argumentClasses)
+                                        ExtractionUtils.isInvokable(false, method, argumentClasses)
                                                 && ExtractionUtils.isAssignable(
-                                                        outputClass, method.getReturnType(), true));
+                                                        outputClass,
+                                                        method.getReturnType(),
+                                                        true,
+                                                        false));
         if (!isMatching) {
             throw new ValidationException(
                     String.format(
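PROCESS_TABLE_EVAL names the PTF entry-point method ("eval"), and validateClassForRuntime now calls extended ExtractionUtils overloads that take an extra boolean flag; the precise semantics of those flags live in ExtractionUtils, which is not shown in this excerpt. A simplified sketch of the check's idea — plain JDK reflection, not Flink's actual implementation, which additionally handles autoboxing, varargs, and the new flags:

```java
import java.lang.reflect.Method;
import java.util.Arrays;

public class RuntimeCheckSketch {

    // Does the class expose an "eval" method invokable with the given argument classes
    // whose return type is assignable to the expected output class?
    static boolean hasInvokableEval(
            Class<?> functionClass, Class<?>[] argumentClasses, Class<?> outputClass) {
        return Arrays.stream(functionClass.getMethods())
                .filter(m -> m.getName().equals("eval"))
                .anyMatch(
                        m ->
                                Arrays.equals(m.getParameterTypes(), argumentClasses)
                                        && outputClass.isAssignableFrom(m.getReturnType()));
    }
}
```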