From 7bf93f2992fa985f1880b4b1b2ab03e99a221f63 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 16 Dec 2024 21:31:06 +0100 Subject: [PATCH] [FLINK-36706] Refactor TypeInferenceExtractor for PTFs --- .../TypeInferenceExtractorScalaTest.scala | 51 +- .../flink/table/annotation/ArgumentTrait.java | 16 +- .../table/functions/ProcessTableFunction.java | 4 +- .../functions/UserDefinedFunctionHelper.java | 9 +- .../extraction/BaseMappingExtractor.java | 583 +++++++--- .../types/extraction/DataTypeExtractor.java | 12 +- .../types/extraction/ExtractionUtils.java | 56 +- .../extraction/FunctionArgumentTemplate.java | 31 +- .../extraction/FunctionMappingExtractor.java | 280 +++-- .../extraction/FunctionResultTemplate.java | 131 ++- .../extraction/FunctionSignatureTemplate.java | 83 +- .../types/extraction/FunctionTemplate.java | 239 +++- .../extraction/ProcedureMappingExtractor.java | 93 +- .../extraction/TypeInferenceExtractor.java | 278 +++-- .../table/types/inference/StaticArgument.java | 133 ++- .../types/inference/StaticArgumentTrait.java | 20 +- .../table/types/inference/TypeInference.java | 21 + .../extraction/DataTypeExtractorTest.java | 3 +- .../TypeInferenceExtractorTest.java | 1025 ++++++++++------- .../TypeInferenceOperandChecker.java | 50 +- .../PlannerCallProcedureOperation.java | 5 +- .../SqlNodeToCallOperationTest.java | 2 +- .../utils/userDefinedScalarFunctions.scala | 2 +- .../utils/UserDefinedFunctionTestUtils.scala | 19 +- 24 files changed, 2056 insertions(+), 1090 deletions(-) diff --git a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala index bf79a1eb7d92f8..f2286754f6b4e4 100644 --- a/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala +++ b/flink-table/flink-table-api-scala/src/test/scala/org/apache/flink/table/types/extraction/TypeInferenceExtractorScalaTest.scala @@ -21,12 +21,13 @@ import org.apache.flink.table.annotation.{DataTypeHint, FunctionHint} import org.apache.flink.table.api.DataTypes import org.apache.flink.table.functions.ScalarFunction import org.apache.flink.table.types.extraction.TypeInferenceExtractorTest.TestSpec -import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, TypeStrategies} +import org.apache.flink.table.types.inference.{ArgumentTypeStrategy, InputTypeStrategies, StaticArgument, TypeStrategies} import org.assertj.core.api.AssertionsForClassTypes.assertThat import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.MethodSource +import java.util import java.util.{stream, Optional} import scala.annotation.varargs @@ -36,19 +37,10 @@ class TypeInferenceExtractorScalaTest { @ParameterizedTest @MethodSource(Array("testData")) - def testArgumentNames(testSpec: TestSpec): Unit = { - if (testSpec.expectedArgumentNames != null) { - assertThat(testSpec.typeInferenceExtraction.get.getNamedArguments) - .isEqualTo(Optional.of(testSpec.expectedArgumentNames)) - } - } - - @ParameterizedTest - @MethodSource(Array("testData")) - def testArgumentTypes(testSpec: TestSpec): Unit = { - if (testSpec.expectedArgumentTypes != null) { - assertThat(testSpec.typeInferenceExtraction.get.getTypedArguments) - .isEqualTo(Optional.of(testSpec.expectedArgumentTypes)) + def testStaticArguments(testSpec: TestSpec): Unit = { + if 
(testSpec.expectedStaticArguments != null) { + val staticArguments = testSpec.typeInferenceExtraction.get.getStaticArguments + assertThat(staticArguments).isEqualTo(Optional.of(testSpec.expectedStaticArguments)) } } @@ -56,8 +48,13 @@ class TypeInferenceExtractorScalaTest { @MethodSource(Array("testData")) def testOutputTypeStrategy(testSpec: TestSpec): Unit = { if (!testSpec.expectedOutputStrategies.isEmpty) { - assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy) - .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies)) + if (testSpec.expectedOutputStrategies.size == 1) { + assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy) + .isEqualTo(testSpec.expectedOutputStrategies.values.iterator.next) + } else { + assertThat(testSpec.typeInferenceExtraction.get.getOutputTypeStrategy) + .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies)) + } } } } @@ -68,22 +65,12 @@ object TypeInferenceExtractorScalaTest { // Scala function with data type hint TestSpec .forScalarFunction(classOf[ScalaScalarFunction]) - .expectNamedArguments("i", "s", "d") - .expectTypedArguments( - DataTypes.INT.notNull().bridgedTo(classOf[Int]), - DataTypes.STRING, - DataTypes.DECIMAL(10, 4)) - .expectOutputMapping( - InputTypeStrategies.sequence( - Array[String]("i", "s", "d"), - Array[ArgumentTypeStrategy]( - InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])), - InputTypeStrategies.explicit(DataTypes.STRING), - InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)) - ) - ), - TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean])) - ), + .expectStaticArgument( + StaticArgument.scalar("i", DataTypes.INT.notNull().bridgedTo(classOf[Int]), false)) + .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING, false)) + .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DECIMAL(10, 4), false)) + .expectOutput(TypeStrategies.explicit( + DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))), TestSpec .forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction]) .expectOutputMapping( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java index df44b5a64f7be7..fe161faac7cb7d 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/annotation/ArgumentTrait.java @@ -40,7 +40,7 @@ public enum ArgumentTrait { * *

It's the default if no {@link ArgumentHint} is provided. */ - SCALAR(StaticArgumentTrait.SCALAR), + SCALAR(true, StaticArgumentTrait.SCALAR), /** * An argument that accepts a table "as row" (i.e. with row semantics). This trait only applies @@ -56,7 +56,7 @@ public enum ArgumentTrait { * can be processed independently. The framework is free in how to distribute rows across * virtual processors and each virtual processor has access only to the currently processed row. */ - TABLE_AS_ROW(StaticArgumentTrait.TABLE_AS_ROW), + TABLE_AS_ROW(true, StaticArgumentTrait.TABLE_AS_ROW), /** * An argument that accepts a table "as set" (i.e. with set semantics). This trait only applies @@ -77,22 +77,28 @@ public enum ArgumentTrait { *

It is also possible not to provide a key ({@link #OPTIONAL_PARTITION_BY}), in which case * only one virtual processor handles the entire table, thereby losing scalability benefits. */ - TABLE_AS_SET(StaticArgumentTrait.TABLE_AS_SET), + TABLE_AS_SET(true, StaticArgumentTrait.TABLE_AS_SET), /** * Defines that a PARTITION BY clause is optional for {@link #TABLE_AS_SET}. By default, it is * mandatory for improving the parallel execution by distributing the table by key. */ - OPTIONAL_PARTITION_BY(StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET); + OPTIONAL_PARTITION_BY(false, StaticArgumentTrait.OPTIONAL_PARTITION_BY, TABLE_AS_SET); + private final boolean isRoot; private final StaticArgumentTrait staticTrait; private final Set requirements; - ArgumentTrait(StaticArgumentTrait staticTrait, ArgumentTrait... requirements) { + ArgumentTrait(boolean isRoot, StaticArgumentTrait staticTrait, ArgumentTrait... requirements) { + this.isRoot = isRoot; this.staticTrait = staticTrait; this.requirements = Arrays.stream(requirements).collect(Collectors.toSet()); } + public boolean isRoot() { + return isRoot; + } + public Set getRequirements() { return requirements; } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java index 1e2e215dca8fcd..4dd05ebb1cbb14 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/ProcessTableFunction.java @@ -24,6 +24,7 @@ import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.FunctionHint; import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.types.extraction.TypeInferenceExtractor; import org.apache.flink.table.types.inference.TypeInference; import org.apache.flink.util.Collector; @@ -225,8 +226,9 @@ public final FunctionKind getKind() { } @Override + @SuppressWarnings({"unchecked", "rawtypes"}) public TypeInference getTypeInference(DataTypeFactory typeFactory) { - throw new UnsupportedOperationException("Type inference is not implemented yet."); + return TypeInferenceExtractor.forProcessTableFunction(typeFactory, (Class) getClass()); } /** diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java index 354146d70c3717..23307852cbdc6c 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/UserDefinedFunctionHelper.java @@ -92,6 +92,8 @@ public final class UserDefinedFunctionHelper { public static final String ASYNC_TABLE_EVAL = "eval"; + public static final String PROCESS_TABLE_EVAL = "eval"; + /** * Tries to infer the TypeInformation of an AggregateFunction's accumulator type. 
* @@ -320,9 +322,12 @@ public static void validateClassForRuntime( methods.stream() .anyMatch( method -> - ExtractionUtils.isInvokable(method, argumentClasses) + ExtractionUtils.isInvokable(false, method, argumentClasses) && ExtractionUtils.isAssignable( - outputClass, method.getReturnType(), true)); + outputClass, + method.getReturnType(), + true, + false)); if (!isMatching) { throw new ValidationException( String.format( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java index 8e3aab464e9351..2e849eb1d68e23 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java @@ -21,13 +21,21 @@ import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.ArgumentTrait; import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.data.RowData; import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.procedures.Procedure; import org.apache.flink.table.types.CollectionDataType; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate; +import org.apache.flink.table.types.inference.StaticArgumentTrait; +import org.apache.flink.types.Row; + +import org.apache.commons.lang3.ArrayUtils; import javax.annotation.Nullable; @@ -36,6 +44,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.Arrays; +import java.util.EnumSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -65,27 +74,110 @@ abstract class BaseMappingExtractor { protected final DataTypeFactory typeFactory; - private final String methodName; + protected final String methodName; private final SignatureExtraction signatureExtraction; protected final ResultExtraction outputExtraction; - protected final MethodVerification verification; + protected final MethodVerification outputVerification; public BaseMappingExtractor( DataTypeFactory typeFactory, String methodName, SignatureExtraction signatureExtraction, ResultExtraction outputExtraction, - MethodVerification verification) { + MethodVerification outputVerification) { this.typeFactory = typeFactory; this.methodName = methodName; this.signatureExtraction = signatureExtraction; this.outputExtraction = outputExtraction; - this.verification = verification; + this.outputVerification = outputVerification; + } + + Map extractOutputMapping() { + try { + return extractResultMappings( + outputExtraction, FunctionTemplate::getOutputTemplate, outputVerification); + } catch (Throwable t) { + throw extractionError(t, "Error in extracting a signature to output mapping."); + } + } + + // -------------------------------------------------------------------------------------------- + // Extraction strategies + // -------------------------------------------------------------------------------------------- + + /** 
+ * Extraction that uses the method parameters for producing a {@link FunctionSignatureTemplate}. + */ + static SignatureExtraction createArgumentsFromParametersExtraction( + int offset, @Nullable Class contextClass) { + return (extractor, method) -> { + final List args = + extractArgumentParameters(method, offset, contextClass); + + final EnumSet[] argumentTraits = extractArgumentTraits(args); + + final List argumentTemplates = + extractArgumentTemplates( + extractor.typeFactory, extractor.getFunctionClass(), args); + + final String[] argumentNames = extractArgumentNames(method, args); + + final boolean[] argumentOptionals = extractArgumentOptionals(args); + + return FunctionSignatureTemplate.of( + argumentTemplates, + method.isVarArgs(), + argumentTraits, + argumentNames, + argumentOptionals); + }; + } + + /** + * Extraction that uses the method parameters for producing a {@link FunctionSignatureTemplate}. + */ + static SignatureExtraction createArgumentsFromParametersExtraction(int offset) { + return createArgumentsFromParametersExtraction(offset, null); + } + + /** Extraction that uses the method parameters with {@link StateHint} for state entries. */ + static ResultExtraction createStateFromParametersExtraction() { + return (extractor, method) -> { + final List stateParameters = extractStateParameters(method); + return createStateTemplateFromParameters(extractor, method, stateParameters); + }; + } + + /** + * Extraction that uses a generic type variable for producing a {@link FunctionStateTemplate}. + * Or method parameters with {@link StateHint} for state entries as a fallback. + */ + static ResultExtraction createStateFromGenericInClassOrParameters( + Class baseClass, int genericPos) { + return (extractor, method) -> { + final List stateParameters = extractStateParameters(method); + if (stateParameters.isEmpty()) { + final DataType dataType = + DataTypeExtractor.extractFromGeneric( + extractor.typeFactory, + baseClass, + genericPos, + extractor.getFunctionClass()); + final LinkedHashMap state = new LinkedHashMap<>(); + state.put("acc", dataType); + return FunctionResultTemplate.ofState(state); + } + return createStateTemplateFromParameters(extractor, method, stateParameters); + }; } + // -------------------------------------------------------------------------------------------- + // Methods for subclasses + // -------------------------------------------------------------------------------------------- + protected abstract Set extractGlobalFunctionTemplates(); protected abstract Set extractLocalFunctionTemplates(Method method); @@ -96,27 +188,35 @@ public BaseMappingExtractor( protected abstract String getHintType(); - Map extractOutputMapping() { - try { - return extractResultMappings( - outputExtraction, FunctionTemplate::getOutputTemplate, verification); - } catch (Throwable t) { - throw extractionError(t, "Error in extracting a signature to output mapping."); - } + protected static Class[] assembleParameters(List> state, List> arguments) { + return Stream.concat(state.stream(), arguments.stream()).toArray(Class[]::new); + } + + protected static ValidationException createMethodNotFoundError( + String methodName, + Class[] parameters, + @Nullable Class returnType, + String pattern) { + return extractionError( + "Considering all hints, the method should comply with the signature:\n%s%s", + createMethodSignatureString(methodName, parameters, returnType), + pattern.isEmpty() ? 
"" : "\nPattern: " + pattern); } /** - * Extracts mappings from signature to result (either accumulator or output) for the entire - * function. Verifies if the extracted inference matches with the implementation. + * Extracts mappings from signature to result (either state or output) for the entire function. + * Verifies if the extracted inference matches with the implementation. * *

For example, from {@code (INT, BOOLEAN, ANY) -> INT}. It does this by going through all * implementation methods and collecting all "per-method" mappings. The function mapping is the * union of all "per-method" mappings. */ - protected Map extractResultMappings( - ResultExtraction resultExtraction, - Function accessor, - MethodVerification verification) { + @SuppressWarnings("unchecked") + protected + Map extractResultMappings( + ResultExtraction resultExtraction, + Function accessor, + @Nullable MethodVerification verification) { final Set global = extractGlobalFunctionTemplates(); final Set globalResultOnly = findResultOnlyTemplates(global, accessor); @@ -125,7 +225,7 @@ protected Map extractResultMa final Map collectedMappings = new LinkedHashMap<>(); final List methods = collectMethods(methodName); - if (methods.size() == 0) { + if (methods.isEmpty()) { throw extractionError( "Could not find a publicly accessible method named '%s'.", methodName); } @@ -145,9 +245,6 @@ protected Map extractResultMa // check if the method can be called verifyMappingForMethod(correctMethod, collectedMappingsPerMethod, verification); - // check if we declare optional on a primitive type parameter - verifyOptionalOnPrimitiveParameter(correctMethod, collectedMappingsPerMethod); - // check if method strategies conflict with function strategies collectedMappingsPerMethod.forEach( (signature, result) -> putMapping(collectedMappings, signature, result)); @@ -158,48 +255,55 @@ protected Map extractResultMa method.toString()); } } - return collectedMappings; + return (Map) collectedMappings; } - /** - * Special case for Scala which generates two methods when using var-args (a {@code Seq < String - * >} and {@code String...}). This method searches for the Java-like variant. - */ - static Method correctVarArgMethod(Method method) { - final int paramCount = method.getParameterCount(); - final Class[] paramClasses = method.getParameterTypes(); - if (paramCount > 0 - && paramClasses[paramCount - 1].getName().equals("scala.collection.Seq")) { - final Type[] paramTypes = method.getGenericParameterTypes(); - final ParameterizedType seqType = (ParameterizedType) paramTypes[paramCount - 1]; - final Type varArgType = seqType.getActualTypeArguments()[0]; - return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName()) - .stream() - .filter(Method::isVarArgs) - .filter(candidate -> candidate.getParameterCount() == paramCount) - .filter( - candidate -> { - final Type[] candidateParamTypes = - candidate.getGenericParameterTypes(); - for (int i = 0; i < paramCount - 1; i++) { - if (candidateParamTypes[i] != paramTypes[i]) { - return false; - } - } - final Class candidateVarArgType = - candidate.getParameterTypes()[paramCount - 1]; - return candidateVarArgType.isArray() - && - // check for Object is needed in case of Scala primitives - // (e.g. 
Int) - (varArgType == Object.class - || candidateVarArgType.getComponentType() - == varArgType); - }) - .findAny() - .orElse(method); + protected static void checkNoState(@Nullable List> state) { + if (state != null && !state.isEmpty()) { + throw extractionError("State is not supported for this kind of function."); + } + } + + protected static void checkSingleState(@Nullable List> state) { + if (state == null || state.size() != 1) { + throw extractionError( + "Aggregating functions need exactly one state entry for the accumulator."); } - return method; + } + + // -------------------------------------------------------------------------------------------- + // Helper methods + // -------------------------------------------------------------------------------------------- + + private static FunctionStateTemplate createStateTemplateFromParameters( + BaseMappingExtractor extractor, Method method, List stateParameters) { + final String[] argumentNames = extractStateNames(method, stateParameters); + if (argumentNames == null) { + throw extractionError("Unable to extract names for all state entries."); + } + + final List dataTypes = + stateParameters.stream() + .map( + s -> + DataTypeExtractor.extractFromMethodParameter( + extractor.typeFactory, + extractor.getFunctionClass(), + s.method, + s.pos)) + .collect(Collectors.toList()); + + final LinkedHashMap state = + IntStream.range(0, dataTypes.size()) + .mapToObj(i -> Map.entry(argumentNames[i], dataTypes.get(i))) + .collect( + Collectors.toMap( + Map.Entry::getKey, + Map.Entry::getValue, + (o, n) -> o, + LinkedHashMap::new)); + + return FunctionResultTemplate.ofState(state); } /** @@ -247,9 +351,47 @@ private Map collectMethodMapp return collectedMappingsPerMethod; } - // -------------------------------------------------------------------------------------------- - // Helper methods (ordered by invocation order) - // -------------------------------------------------------------------------------------------- + /** + * Special case for Scala which generates two methods when using var-args (a {@code Seq < String + * >} and {@code String...}). This method searches for the Java-like variant. + */ + private static Method correctVarArgMethod(Method method) { + final int paramCount = method.getParameterCount(); + final Class[] paramClasses = method.getParameterTypes(); + if (paramCount > 0 + && paramClasses[paramCount - 1].getName().equals("scala.collection.Seq")) { + final Type[] paramTypes = method.getGenericParameterTypes(); + final ParameterizedType seqType = (ParameterizedType) paramTypes[paramCount - 1]; + final Type varArgType = seqType.getActualTypeArguments()[0]; + return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName()) + .stream() + .filter(Method::isVarArgs) + .filter(candidate -> candidate.getParameterCount() == paramCount) + .filter( + candidate -> { + final Type[] candidateParamTypes = + candidate.getGenericParameterTypes(); + for (int i = 0; i < paramCount - 1; i++) { + if (candidateParamTypes[i] != paramTypes[i]) { + return false; + } + } + final Class candidateVarArgType = + candidate.getParameterTypes()[paramCount - 1]; + return candidateVarArgType.isArray() + && + // check for Object is needed in case of Scala primitives + // (e.g. Int) + (varArgType == Object.class + || candidateVarArgType.getComponentType() + == varArgType); + }) + .findAny() + .orElse(method); + } + return method; + } + /** Explicit mappings with complete signature to result declaration. 
*/ private void putExplicitMappings( Map collectedMappings, @@ -322,166 +464,259 @@ else if (!existingResult.equals(result)) { private void verifyMappingForMethod( Method method, Map collectedMappingsPerMethod, - MethodVerification verification) { + @Nullable MethodVerification verification) { + if (verification == null) { + return; + } collectedMappingsPerMethod.forEach( - (signature, result) -> - verification.verify(method, signature.toClass(), result.toClass())); - } - - private void verifyOptionalOnPrimitiveParameter( - Method method, - Map collectedMappingsPerMethod) { - collectedMappingsPerMethod - .keySet() - .forEach( - signature -> { - Boolean[] argumentOptional = signature.argumentOptionals; - // Here we restrict that the argument must contain optional parameters - // in order to obtain the FunctionSignatureTemplate of the method for - // verification. Therefore, the extract method will only be called once. - // If no function hint is set, this verify will not be executed. - if (argumentOptional != null - && Arrays.stream(argumentOptional) - .anyMatch(Boolean::booleanValue)) { - FunctionSignatureTemplate functionResultTemplate = - signatureExtraction.extract(this, method); - for (int i = 0; i < argumentOptional.length; i++) { - DataType dataType = - functionResultTemplate.argumentTemplates.get(i) - .dataType; - if (dataType != null - && argumentOptional[i] - && dataType.getConversionClass() != null - && dataType.getConversionClass().isPrimitive()) { - throw extractionError( - "Argument at position %d is optional but a primitive type doesn't accept null value.", - i); - } - } - } - }); + (signature, result) -> { + if (result instanceof FunctionStateTemplate) { + final FunctionStateTemplate stateTemplate = (FunctionStateTemplate) result; + verification.verify( + method, stateTemplate.toClassList(), signature.toClassList(), null); + } else if (result instanceof FunctionOutputTemplate) { + final FunctionOutputTemplate outputTemplate = + (FunctionOutputTemplate) result; + verification.verify( + method, + List.of(), + signature.toClassList(), + outputTemplate.toClass()); + } + }); } // -------------------------------------------------------------------------------------------- - // Context sensitive extraction and verification logic + // Parameters extraction (i.e. state and arguments) // -------------------------------------------------------------------------------------------- - /** - * Extraction that uses the method parameters for producing a {@link FunctionSignatureTemplate}. - */ - static SignatureExtraction createParameterSignatureExtraction(int offset) { - return (extractor, method) -> { - final List parameterTypes = - extractArgumentTemplates( - extractor.typeFactory, extractor.getFunctionClass(), method, offset); + /** Method parameter that qualifies as a function argument (i.e. not a context or state). */ + private static class ArgumentParameter { + final Parameter parameter; + final Method method; + // Pos in the method, not necessarily in the extracted function + final int pos; + + private ArgumentParameter(Parameter parameter, Method method, int pos) { + this.parameter = parameter; + this.method = method; + this.pos = pos; + } + } - final String[] argumentNames = extractArgumentNames(method, offset); + /** Method parameter that qualifies as a function state (i.e. not a context or argument). 
*/ + private static class StateParameter { + final Parameter parameter; + final Method method; + // Pos in the method, not necessarily in the extracted function + final int pos; + + private StateParameter(Parameter parameter, Method method, int pos) { + this.parameter = parameter; + this.method = method; + this.pos = pos; + } + } - final Boolean[] argumentOptionals = extractArgumentOptionals(method, offset); + private static List extractArgumentParameters( + Method method, int offset, @Nullable Class contextClass) { + final Parameter[] parameters = method.getParameters(); + return IntStream.range(0, parameters.length) + .mapToObj( + pos -> { + final Parameter parameter = parameters[pos]; + return new ArgumentParameter(parameter, method, pos); + }) + .skip(offset) + .filter(arg -> contextClass == null || arg.parameter.getType() != contextClass) + .filter(arg -> arg.parameter.getAnnotation(StateHint.class) == null) + .collect(Collectors.toList()); + } - return FunctionSignatureTemplate.of( - parameterTypes, method.isVarArgs(), argumentNames, argumentOptionals); - }; + private static List extractStateParameters(Method method) { + final Parameter[] parameters = method.getParameters(); + return IntStream.range(0, parameters.length) + .mapToObj( + pos -> { + final Parameter parameter = parameters[pos]; + return new StateParameter(parameter, method, pos); + }) + .filter(arg -> arg.parameter.getAnnotation(StateHint.class) != null) + .collect(Collectors.toList()); } private static List extractArgumentTemplates( - DataTypeFactory typeFactory, Class extractedClass, Method method, int offset) { - return IntStream.range(offset, method.getParameterCount()) - .mapToObj( - i -> + DataTypeFactory typeFactory, Class extractedClass, List args) { + return args.stream() + .map( + arg -> // check for input group before start extracting a data type - tryExtractInputGroupArgument(method, i) + tryExtractInputGroupArgument(arg) .orElseGet( () -> - extractDataTypeArgument( - typeFactory, - extractedClass, - method, - i))) + extractArgumentByKind( + typeFactory, extractedClass, arg))) .collect(Collectors.toList()); } - static Optional tryExtractInputGroupArgument( - Method method, int paramPos) { - final Parameter parameter = method.getParameters()[paramPos]; - final DataTypeHint hint = parameter.getAnnotation(DataTypeHint.class); - final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class); - if (hint != null && argumentHint != null) { + private static Optional tryExtractInputGroupArgument( + ArgumentParameter arg) { + final DataTypeHint dataTypehint = arg.parameter.getAnnotation(DataTypeHint.class); + final ArgumentHint argumentHint = arg.parameter.getAnnotation(ArgumentHint.class); + if (dataTypehint != null && argumentHint != null) { throw extractionError( - "Argument and dataType hints cannot be declared in the same parameter at position %d.", - paramPos); + "Argument and data type hints cannot be declared at the same time at position %d.", + arg.pos); } if (argumentHint != null) { final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(argumentHint, null); if (template.inputGroup != null) { - return Optional.of(FunctionArgumentTemplate.of(template.inputGroup)); + return Optional.of(FunctionArgumentTemplate.ofInputGroup(template.inputGroup)); } - } else if (hint != null) { - final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(hint, null); + } else if (dataTypehint != null) { + final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(dataTypehint, null); if 
(template.inputGroup != null) { - return Optional.of(FunctionArgumentTemplate.of(template.inputGroup)); + return Optional.of(FunctionArgumentTemplate.ofInputGroup(template.inputGroup)); } } return Optional.empty(); } - private static FunctionArgumentTemplate extractDataTypeArgument( - DataTypeFactory typeFactory, Class extractedClass, Method method, int paramPos) { + private static FunctionArgumentTemplate extractArgumentByKind( + DataTypeFactory typeFactory, Class extractedClass, ArgumentParameter arg) { + final Parameter parameter = arg.parameter; + final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class); + final int pos = arg.pos; + final Set rootTrait = + Optional.ofNullable(argumentHint) + .map( + hint -> + Arrays.stream(hint.value()) + .filter(ArgumentTrait::isRoot) + .collect(Collectors.toSet())) + .orElse(Set.of(ArgumentTrait.SCALAR)); + if (rootTrait.size() != 1) { + throw extractionError( + "Incorrect argument kind at position %d. Argument kind must be one of: %s", + pos, + Arrays.stream(ArgumentTrait.values()) + .filter(ArgumentTrait::isRoot) + .collect(Collectors.toList())); + } + + if (rootTrait.contains(ArgumentTrait.SCALAR)) { + return extractScalarArgument(typeFactory, extractedClass, arg); + } else if (rootTrait.contains(ArgumentTrait.TABLE_AS_ROW) + || rootTrait.contains(ArgumentTrait.TABLE_AS_SET)) { + return extractTableArgument(typeFactory, argumentHint, extractedClass, arg); + } else { + throw extractionError("Unknown argument kind."); + } + } + + private static FunctionArgumentTemplate extractTableArgument( + DataTypeFactory typeFactory, + ArgumentHint argumentHint, + Class extractedClass, + ArgumentParameter arg) { + try { + final DataType type = + DataTypeExtractor.extractFromMethodParameter( + typeFactory, extractedClass, arg.method, arg.pos); + return FunctionArgumentTemplate.ofDataType(type); + } catch (Throwable t) { + final Class paramClass = arg.parameter.getType(); + final Class argClass = argumentHint.type().bridgedTo(); + if (argClass == Row.class || argClass == RowData.class) { + return FunctionArgumentTemplate.ofTable(argClass); + } + if (paramClass == Row.class || paramClass == RowData.class) { + return FunctionArgumentTemplate.ofTable(paramClass); + } + // Just a regular error for a typed argument + throw t; + } + } + + private static FunctionArgumentTemplate extractScalarArgument( + DataTypeFactory typeFactory, Class extractedClass, ArgumentParameter arg) { final DataType type = DataTypeExtractor.extractFromMethodParameter( - typeFactory, extractedClass, method, paramPos); + typeFactory, extractedClass, arg.method, arg.pos); // unwrap data type in case of varargs - if (method.isVarArgs() && paramPos == method.getParameterCount() - 1) { + if (arg.parameter.isVarArgs()) { // for ARRAY if (type instanceof CollectionDataType) { - return FunctionArgumentTemplate.of( + return FunctionArgumentTemplate.ofDataType( ((CollectionDataType) type).getElementDataType()); } // special case for varargs that have been misinterpreted as BYTES else if (type.equals(DataTypes.BYTES())) { - return FunctionArgumentTemplate.of( + return FunctionArgumentTemplate.ofDataType( DataTypes.TINYINT().notNull().bridgedTo(byte.class)); } } - return FunctionArgumentTemplate.of(type); + return FunctionArgumentTemplate.ofDataType(type); } - static @Nullable String[] extractArgumentNames(Method method, int offset) { + @SuppressWarnings("unchecked") + private static EnumSet[] extractArgumentTraits( + List args) { + return args.stream() + .map( + arg -> { + final 
ArgumentHint argumentHint = + arg.parameter.getAnnotation(ArgumentHint.class); + if (argumentHint == null) { + return EnumSet.of(StaticArgumentTrait.SCALAR); + } + final List traits = + Arrays.stream(argumentHint.value()) + .map(ArgumentTrait::toStaticTrait) + .collect(Collectors.toList()); + return EnumSet.copyOf(traits); + }) + .toArray(EnumSet[]::new); + } + + private static @Nullable String[] extractArgumentNames( + Method method, List args) { final List methodParameterNames = ExtractionUtils.extractMethodParameterNames(method); if (methodParameterNames != null) { - return methodParameterNames - .subList(offset, methodParameterNames.size()) - .toArray(new String[0]); + return args.stream() + .map(arg -> methodParameterNames.get(arg.pos)) + .toArray(String[]::new); } else { return null; } } - static Boolean[] extractArgumentOptionals(Method method, int offset) { - return Arrays.stream(method.getParameters()) - .skip(offset) - .map(parameter -> parameter.getAnnotation(ArgumentHint.class)) - .map( - h -> { - if (h == null) { - return false; - } - final ArgumentTrait[] traits = h.value(); - if (traits.length != 1 || traits[0] != ArgumentTrait.SCALAR) { - throw extractionError( - "Only scalar arguments are supported so far."); - } - return h.isOptional(); - }) - .toArray(Boolean[]::new); + private static @Nullable String[] extractStateNames(Method method, List state) { + final List methodParameterNames = + ExtractionUtils.extractMethodParameterNames(method); + if (methodParameterNames != null) { + return state.stream() + .map(arg -> methodParameterNames.get(arg.pos)) + .toArray(String[]::new); + } else { + return null; + } } - protected static ValidationException createMethodNotFoundError( - String methodName, Class[] parameters, @Nullable Class returnType) { - return extractionError( - "Considering all hints, the method should comply with the signature:\n%s", - createMethodSignatureString(methodName, parameters, returnType)); + private static boolean[] extractArgumentOptionals(List args) { + final Boolean[] argumentOptionals = + args.stream() + .map(arg -> arg.parameter.getAnnotation(ArgumentHint.class)) + .map( + hint -> { + if (hint == null) { + return false; + } + return hint.isOptional(); + }) + .toArray(Boolean[]::new); + return ArrayUtils.toPrimitive(argumentOptionals); } // -------------------------------------------------------------------------------------------- @@ -489,18 +724,22 @@ protected static ValidationException createMethodNotFoundError( // -------------------------------------------------------------------------------------------- /** Extracts a {@link FunctionSignatureTemplate} from a method. */ - protected interface SignatureExtraction { + interface SignatureExtraction { FunctionSignatureTemplate extract(BaseMappingExtractor extractor, Method method); } /** Extracts a {@link FunctionResultTemplate} from a class or method. */ - protected interface ResultExtraction { + interface ResultExtraction { @Nullable FunctionResultTemplate extract(BaseMappingExtractor extractor, Method method); } /** Verifies the signature of a method. 
*/ protected interface MethodVerification { - void verify(Method method, List> arguments, Class result); + void verify( + Method method, + List> state, + List> arguments, + @Nullable Class result); } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java index 203b717ec117ca..e3db03d8b2c0c4 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/DataTypeExtractor.java @@ -22,6 +22,7 @@ import org.apache.flink.api.java.typeutils.AvroUtils; import org.apache.flink.table.annotation.ArgumentHint; import org.apache.flink.table.annotation.DataTypeHint; +import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.dataview.DataView; import org.apache.flink.table.api.dataview.ListView; @@ -144,8 +145,11 @@ public static DataType extractFromMethodParameter( final Parameter parameter = method.getParameters()[paramPos]; final DataTypeHint hint = parameter.getAnnotation(DataTypeHint.class); final ArgumentHint argumentHint = parameter.getAnnotation(ArgumentHint.class); + final StateHint stateHint = parameter.getAnnotation(StateHint.class); final DataTypeTemplate template; - if (argumentHint != null) { + if (stateHint != null) { + template = DataTypeTemplate.fromAnnotation(typeFactory, stateHint.type()); + } else if (argumentHint != null) { template = DataTypeTemplate.fromAnnotation(typeFactory, argumentHint.type()); } else if (hint != null) { template = DataTypeTemplate.fromAnnotation(typeFactory, hint); @@ -206,9 +210,9 @@ public static DataType extractFromGenericMethodParameter( * Extracts a data type from a method return type by considering surrounding classes and method * annotation. */ - public static DataType extractFromMethodOutput( + public static DataType extractFromMethodReturnType( DataTypeFactory typeFactory, Class baseClass, Method method) { - return extractFromMethodOutput( + return extractFromMethodReturnType( typeFactory, baseClass, method, method.getGenericReturnType()); } @@ -216,7 +220,7 @@ public static DataType extractFromMethodOutput( * Extracts a data type from a method return type with specifying the method's type explicitly * by considering surrounding classes and method annotation. 
*/ - public static DataType extractFromMethodOutput( + public static DataType extractFromMethodReturnType( DataTypeFactory typeFactory, Class baseClass, Method method, Type methodReturnType) { final DataTypeHint hint = method.getAnnotation(DataTypeHint.class); final DataTypeTemplate template; diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java index 0d01e04a08ffc8..34473f9c8cae91 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ExtractionUtils.java @@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.table.annotation.ArgumentHint; +import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.catalog.DataTypeFactory; @@ -91,7 +92,8 @@ public static List collectMethods(Class function, String methodName) *

E.g., {@code (int.class, int.class)} matches {@code f(Object...), f(int, int), f(Integer, * Object)} and so forth. */ - public static boolean isInvokable(Executable executable, Class... classes) { + public static boolean isInvokable( + boolean strictAutoboxing, Executable executable, Class... classes) { final int m = executable.getModifiers(); if (!Modifier.isPublic(m)) { return false; @@ -110,21 +112,25 @@ public static boolean isInvokable(Executable executable, Class... classes) { if (currentParam == paramCount - 1 && executable.isVarArgs()) { final Class paramComponent = executable.getParameterTypes()[currentParam].getComponentType(); - // we have more than 1 classes left so the vararg needs to consume them all + // we have more than one class left so the vararg needs to consume them all if (classCount - currentClass > 1) { while (currentClass < classCount && ExtractionUtils.isAssignable( - classes[currentClass], paramComponent, true)) { + classes[currentClass], + paramComponent, + true, + strictAutoboxing)) { currentClass++; } } else if (currentClass < classCount - && (parameterMatches(classes[currentClass], param) - || parameterMatches(classes[currentClass], paramComponent))) { + && (parameterMatches(strictAutoboxing, classes[currentClass], param) + || parameterMatches( + strictAutoboxing, classes[currentClass], paramComponent))) { currentClass++; } } // entire parameter matches - else if (parameterMatches(classes[currentClass], param)) { + else if (parameterMatches(strictAutoboxing, classes[currentClass], param)) { currentClass++; } } @@ -132,8 +138,9 @@ else if (parameterMatches(classes[currentClass], param)) { return currentClass == classCount; } - private static boolean parameterMatches(Class clz, Class param) { - return clz == null || ExtractionUtils.isAssignable(clz, param, true); + private static boolean parameterMatches( + boolean strictAutoboxing, Class clz, Class param) { + return clz == null || ExtractionUtils.isAssignable(clz, param, true, strictAutoboxing); } /** Creates a method signature string like {@code int eval(Integer, String)}. */ @@ -298,11 +305,11 @@ private static String normalizeAccessorName(String name) { /** * Checks for an invokable constructor matching the given arguments. * - * @see #isInvokable(Executable, Class[]) + * @see #isInvokable(boolean, Executable, Class[]) */ public static boolean hasInvokableConstructor(Class clazz, Class... classes) { for (Constructor constructor : clazz.getDeclaredConstructors()) { - if (isInvokable(constructor, classes)) { + if (isInvokable(false, constructor, classes)) { return true; } } @@ -758,16 +765,20 @@ private AssigningConstructor(Constructor constructor, List parameterN } else { offset = 0; } - // by default parameter names are "arg0, arg1, arg2, ..." if compiler flag is not set - // so we need to extract them manually if possible + // by default parameter names are "arg0, arg1, arg2, ..." 
if compiler flag is not set, + // we need to extract them manually if possible List parameterNames = Stream.of(executable.getParameters()) .map( parameter -> { - ArgumentHint argumentHint = + final StateHint stateHint = + parameter.getAnnotation(StateHint.class); + final ArgumentHint argHint = parameter.getAnnotation(ArgumentHint.class); - if (argumentHint != null && !argumentHint.name().isEmpty()) { - return argumentHint.name(); + if (stateHint != null && !stateHint.name().isEmpty()) { + return stateHint.name(); + } else if (argHint != null && !argHint.name().isEmpty()) { + return argHint.name(); } else { return parameter.getName(); } @@ -787,7 +798,7 @@ private AssigningConstructor(Constructor constructor, List parameterN return null; } // remove "this" and additional local variables - // select less names if class file has not the required information + // select fewer names if class file has not the required information parameterNames = extractedNames.subList( offset, @@ -936,10 +947,11 @@ public void visitLocalVariable( * @param cls the Class to check, may be null * @param toClass the Class to try to assign into, returns false if null * @param autoboxing whether to use implicit autoboxing/unboxing between primitives and wrappers + * @param strictAutoboxing checks whether null would end up in a primitive type and forbids it * @return {@code true} if assignment possible */ public static boolean isAssignable( - Class cls, final Class toClass, final boolean autoboxing) { + Class cls, final Class toClass, boolean autoboxing, boolean strictAutoboxing) { if (toClass == null) { return false; } @@ -955,10 +967,12 @@ public static boolean isAssignable( return false; } } - if (toClass.isPrimitive() && !cls.isPrimitive()) { - cls = wrapperToPrimitive(cls); - if (cls == null) { - return false; + if (!strictAutoboxing) { + if (toClass.isPrimitive() && !cls.isPrimitive()) { + cls = wrapperToPrimitive(cls); + if (cls == null) { + return false; + } } } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java index dd289d98a567b1..778efce2301ef1 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionArgumentTemplate.java @@ -37,21 +37,29 @@ @Internal final class FunctionArgumentTemplate { - final @Nullable DataType dataType; + private final @Nullable DataType dataType; + private final @Nullable InputGroup inputGroup; + private final @Nullable Class conversionClass; - final @Nullable InputGroup inputGroup; - - private FunctionArgumentTemplate(@Nullable DataType dataType, @Nullable InputGroup inputGroup) { + private FunctionArgumentTemplate( + @Nullable DataType dataType, + @Nullable InputGroup inputGroup, + @Nullable Class conversionClass) { this.dataType = dataType; this.inputGroup = inputGroup; + this.conversionClass = conversionClass; + } + + static FunctionArgumentTemplate ofDataType(DataType dataType) { + return new FunctionArgumentTemplate(dataType, null, null); } - static FunctionArgumentTemplate of(DataType dataType) { - return new FunctionArgumentTemplate(dataType, null); + static FunctionArgumentTemplate ofInputGroup(InputGroup inputGroup) { + return new FunctionArgumentTemplate(null, inputGroup, null); } - static FunctionArgumentTemplate 
of(InputGroup inputGroup) { - return new FunctionArgumentTemplate(null, inputGroup); + static FunctionArgumentTemplate ofTable(Class conversionClass) { + return new FunctionArgumentTemplate(null, null, conversionClass); } ArgumentTypeStrategy toArgumentTypeStrategy() { @@ -68,10 +76,17 @@ ArgumentTypeStrategy toArgumentTypeStrategy() { } } + public @Nullable DataType toDataType() { + return dataType; + } + Class toConversionClass() { if (dataType != null) { return dataType.getConversionClass(); } + if (conversionClass != null) { + return conversionClass; + } assert inputGroup != null; switch (inputGroup) { case ANY: diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java index 26909257645a18..44b0491f6a0c8e 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java @@ -24,6 +24,8 @@ import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate; import org.apache.flink.util.Preconditions; import javax.annotation.Nullable; @@ -31,12 +33,13 @@ import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; +import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import static org.apache.flink.table.types.extraction.ExtractionUtils.collectAnnotationsOfClass; @@ -59,81 +62,47 @@ final class FunctionMappingExtractor extends BaseMappingExtractor { private final Class function; - private final @Nullable ResultExtraction accumulatorExtraction; + private final @Nullable ResultExtraction stateExtraction; + private final @Nullable MethodVerification stateVerification; FunctionMappingExtractor( DataTypeFactory typeFactory, Class function, String methodName, SignatureExtraction signatureExtraction, - @Nullable ResultExtraction accumulatorExtraction, + @Nullable ResultExtraction stateExtraction, + @Nullable MethodVerification stateVerification, ResultExtraction outputExtraction, - MethodVerification verification) { - super(typeFactory, methodName, signatureExtraction, outputExtraction, verification); + @Nullable MethodVerification outputVerification) { + super(typeFactory, methodName, signatureExtraction, outputExtraction, outputVerification); this.function = function; - this.accumulatorExtraction = accumulatorExtraction; + this.stateExtraction = stateExtraction; + this.stateVerification = stateVerification; } - Class getFunction() { - return function; - } - - boolean hasAccumulator() { - return accumulatorExtraction != null; - } - - @Override - protected Set extractGlobalFunctionTemplates() { - return TemplateUtils.extractGlobalFunctionTemplates(typeFactory, function); - } - - @Override - protected Set extractLocalFunctionTemplates(Method method) { - return 
TemplateUtils.extractLocalFunctionTemplates(typeFactory, method); - } - - @Override - protected List collectMethods(String methodName) { - return ExtractionUtils.collectMethods(function, methodName); - } - - @Override - protected Class getFunctionClass() { - return function; - } - - @Override - protected String getHintType() { - return "Function"; - } - - Map extractAccumulatorMapping() { - Preconditions.checkState(hasAccumulator()); + Map extractStateMapping() { + Preconditions.checkState(supportsState()); try { return extractResultMappings( - accumulatorExtraction, - FunctionTemplate::getAccumulatorTemplate, - (method, signature, result) -> { - // put the result into the signature for accumulators - final List> arguments = - Stream.concat(Stream.of(result), signature.stream()) - .collect(Collectors.toList()); - verification.verify(method, arguments, null); - }); + stateExtraction, FunctionTemplate::getStateTemplate, stateVerification); } catch (Throwable t) { - throw extractionError(t, "Error in extracting a signature to accumulator mapping."); + throw extractionError(t, "Error in extracting a signature to state mapping."); } } + // -------------------------------------------------------------------------------------------- + // Extraction strategies + // -------------------------------------------------------------------------------------------- + /** - * Extraction that uses the method return type for producing a {@link FunctionResultTemplate}. + * Extraction that uses the method return type for producing a {@link FunctionOutputTemplate}. */ - static ResultExtraction createReturnTypeResultExtraction() { + static ResultExtraction createOutputFromReturnTypeInMethod() { return (extractor, method) -> { final DataType dataType = - DataTypeExtractor.extractFromMethodOutput( + DataTypeExtractor.extractFromMethodReturnType( extractor.typeFactory, extractor.getFunctionClass(), method); - return FunctionResultTemplate.of(dataType); + return FunctionResultTemplate.ofOutput(dataType); }; } @@ -142,7 +111,7 @@ static ResultExtraction createReturnTypeResultExtraction() { * *

If enabled, a {@link DataTypeHint} from method or class has higher priority. */ - static ResultExtraction createGenericResultExtraction( + static ResultExtraction createOutputFromGenericInClass( Class baseClass, int genericPos, boolean allowDataTypeHint) { @@ -159,7 +128,7 @@ static ResultExtraction createGenericResultExtraction( baseClass, genericPos, extractor.getFunctionClass()); - return FunctionResultTemplate.of(dataType); + return FunctionResultTemplate.ofOutput(dataType); }; } @@ -169,7 +138,7 @@ static ResultExtraction createGenericResultExtraction( * *

If enabled, a {@link DataTypeHint} from method or class has higher priority. */ - static ResultExtraction createGenericResultExtractionFromMethod( + static ResultExtraction createOutputFromGenericInMethod( int paramPos, int genericPos, boolean allowDataTypeHint) { return (extractor, method) -> { if (allowDataTypeHint) { @@ -185,101 +154,166 @@ static ResultExtraction createGenericResultExtractionFromMethod( method, paramPos, genericPos); - return FunctionResultTemplate.of(dataType); + return FunctionResultTemplate.ofOutput(dataType); }; } - /** Uses hints to extract functional template. */ - private static Optional extractHints( - BaseMappingExtractor extractor, Method method) { - final Set dataTypeHints = new HashSet<>(); - dataTypeHints.addAll(collectAnnotationsOfMethod(DataTypeHint.class, method)); - dataTypeHints.addAll( - collectAnnotationsOfClass(DataTypeHint.class, extractor.getFunctionClass())); - if (dataTypeHints.size() > 1) { - throw extractionError( - "More than one data type hint found for output of function. " - + "Please use a function hint instead."); - } - if (dataTypeHints.size() == 1) { - return Optional.ofNullable( - FunctionTemplate.createResultTemplate( - extractor.typeFactory, dataTypeHints.iterator().next())); - } - // otherwise continue with regular extraction - return Optional.empty(); - } + // -------------------------------------------------------------------------------------------- + // Verification strategies + // -------------------------------------------------------------------------------------------- - /** Verification that checks a method by parameters and return type. */ + /** Verification that checks a method by parameters (arguments only) and return type. */ static MethodVerification createParameterAndReturnTypeVerification() { - return (method, signature, result) -> { - final Class[] parameters = signature.toArray(new Class[0]); + return (method, state, arguments, result) -> { + checkNoState(state); + final Class[] parameters = assembleParameters(state, arguments); final Class returnType = method.getReturnType(); + // TODO enable strict autoboxing final boolean isValid = - isInvokable(method, parameters) && isAssignable(result, returnType, true); + isInvokable(false, method, parameters) + && isAssignable(result, returnType, true, false); if (!isValid) { - throw createMethodNotFoundError(method.getName(), parameters, result); + throw createMethodNotFoundError(method.getName(), parameters, result, ""); } }; } - /** Verification that checks a method by parameters including an accumulator. */ - static MethodVerification createParameterWithAccumulatorVerification() { - return (method, signature, result) -> { - if (result != null) { - // ignore the accumulator in the first argument - createParameterWithArgumentVerification(null).verify(method, signature, result); + /** Verification that checks a method by parameters (arguments only or with accumulator). */ + static MethodVerification createParameterVerification(boolean requireAccumulator) { + return (method, state, arguments, result) -> { + if (requireAccumulator) { + checkSingleState(state); } else { - // check the signature only - createParameterVerification().verify(method, signature, null); + checkNoState(state); } - }; - } - - /** Verification that checks a method by parameters including an additional first parameter. 
*/ - static MethodVerification createParameterWithArgumentVerification( - @Nullable Class argumentClass) { - return (method, signature, result) -> { - final Class[] parameters = - Stream.concat(Stream.of(argumentClass), signature.stream()) - .toArray(Class[]::new); - if (!isInvokable(method, parameters)) { - throw createMethodNotFoundError(method.getName(), parameters, null); + final Class[] parameters = assembleParameters(state, arguments); + // TODO enable strict autoboxing + if (!isInvokable(false, method, parameters)) { + throw createMethodNotFoundError( + method.getName(), + parameters, + null, + requireAccumulator ? "( [, ]*)" : ""); } }; } - /** Verification that checks a method by parameters including an additional first parameter. */ - static MethodVerification createGenericParameterWithArgumentAndReturnTypeVerification( - Class baseClass, Class argumentClass, int paramPos, int genericPos) { - return (method, signature, result) -> { - final Class[] parameters = - Stream.concat(Stream.of(argumentClass), signature.stream()) + /** + * Verification that checks a method by parameters (arguments only) with mandatory {@link + * CompletableFuture}. + */ + static MethodVerification createParameterAndCompletableFutureVerification(Class baseClass) { + return (method, state, arguments, result) -> { + checkNoState(state); + final Class[] parameters = assembleParameters(state, arguments); + final Class[] parametersWithFuture = + Stream.concat(Stream.of(CompletableFuture.class), Arrays.stream(parameters)) .toArray(Class[]::new); - Type genericType = method.getGenericParameterTypes()[paramPos]; + Type genericType = method.getGenericParameterTypes()[0]; genericType = resolveVariableWithClassContext(baseClass, genericType); if (!(genericType instanceof ParameterizedType)) { throw extractionError( - "The method '%s' needs generic parameters for the %d arg.", - method.getName(), paramPos); + "The method '%s' needs generic parameters for the CompletableFuture at position %d.", + method.getName(), 0); } - Type returnType = - ((ParameterizedType) genericType).getActualTypeArguments()[genericPos]; + final Type returnType = ((ParameterizedType) genericType).getActualTypeArguments()[0]; Class returnClazz = getClassFromType(returnType); - if (!(isInvokable(method, parameters) && isAssignable(result, returnClazz, true))) { - throw createMethodNotFoundError(method.getName(), parameters, null); + // TODO enable strict autoboxing + if (!(isInvokable(false, method, parametersWithFuture) + && isAssignable(result, returnClazz, true, false))) { + throw createMethodNotFoundError( + method.getName(), + parametersWithFuture, + null, + "( [, ]*)"); } }; } - /** Verification that checks a method by parameters. */ - static MethodVerification createParameterVerification() { - return (method, signature, result) -> { - final Class[] parameters = signature.toArray(new Class[0]); - if (!isInvokable(method, parameters)) { - throw createMethodNotFoundError(method.getName(), parameters, null); + /** + * Verification that checks a method by parameters (state and arguments) with optional context. 
+ */ + static MethodVerification createParameterAndOptionalContextVerification( + Class context, boolean allowState) { + return (method, state, arguments, result) -> { + if (!allowState) { + checkNoState(state); + } + final Class[] parameters = assembleParameters(state, arguments); + final Class[] parametersWithContext = + Stream.concat(Stream.of(context), Arrays.stream(parameters)) + .toArray(Class[]::new); + if (!isInvokable(true, method, parameters) + && !isInvokable(true, method, parametersWithContext)) { + throw createMethodNotFoundError( + method.getName(), + parameters, + null, + allowState ? "(? [, ]* [, ]*)" : ""); } }; } + + // -------------------------------------------------------------------------------------------- + // Methods from super class + // -------------------------------------------------------------------------------------------- + + Class getFunction() { + return function; + } + + boolean supportsState() { + return stateExtraction != null; + } + + @Override + protected Set extractGlobalFunctionTemplates() { + return TemplateUtils.extractGlobalFunctionTemplates(typeFactory, function); + } + + @Override + protected Set extractLocalFunctionTemplates(Method method) { + return TemplateUtils.extractLocalFunctionTemplates(typeFactory, method); + } + + @Override + protected List collectMethods(String methodName) { + return ExtractionUtils.collectMethods(function, methodName); + } + + @Override + protected Class getFunctionClass() { + return function; + } + + @Override + protected String getHintType() { + return "Function"; + } + + // -------------------------------------------------------------------------------------------- + // Helper methods + // -------------------------------------------------------------------------------------------- + + /** Uses hints to extract functional template. */ + private static Optional extractHints( + BaseMappingExtractor extractor, Method method) { + final Set dataTypeHints = new HashSet<>(); + dataTypeHints.addAll(collectAnnotationsOfMethod(DataTypeHint.class, method)); + dataTypeHints.addAll( + collectAnnotationsOfClass(DataTypeHint.class, extractor.getFunctionClass())); + if (dataTypeHints.size() > 1) { + throw extractionError( + "More than one data type hint found for output of function. 
" + + "Please use a function hint instead."); + } + if (dataTypeHints.size() == 1) { + return Optional.ofNullable( + FunctionTemplate.createOutputTemplate( + extractor.typeFactory, dataTypeHints.iterator().next())); + } + // otherwise continue with regular extraction + return Optional.empty(); + } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java index e8305f9abe0a31..ce6db9596c47c3 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionResultTemplate.java @@ -20,47 +20,128 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.StateTypeStrategy; +import org.apache.flink.table.types.inference.StateTypeStrategyWrapper; import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.table.types.inference.TypeStrategy; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; -/** Template of a function intermediate result (i.e. accumulator) or final result (i.e. output). */ -@Internal -final class FunctionResultTemplate { +import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError; - final DataType dataType; +/** Template of a function intermediate result (i.e. state) or final result (i.e. output). */ +@Internal +interface FunctionResultTemplate { - private FunctionResultTemplate(DataType dataType) { - this.dataType = dataType; + static FunctionOutputTemplate ofOutput(DataType dataType) { + return new FunctionOutputTemplate(dataType); } - static FunctionResultTemplate of(DataType dataType) { - return new FunctionResultTemplate(dataType); + static FunctionStateTemplate ofState(LinkedHashMap state) { + return new FunctionStateTemplate(state); } - TypeStrategy toTypeStrategy() { - return TypeStrategies.explicit(dataType); - } + class FunctionOutputTemplate implements FunctionResultTemplate { - Class toClass() { - return dataType.getConversionClass(); - } + private final DataType dataType; + + private FunctionOutputTemplate(DataType dataType) { + this.dataType = dataType; + } + + TypeStrategy toTypeStrategy() { + return TypeStrategies.explicit(dataType); + } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; + Class toClass() { + return dataType.getConversionClass(); } - if (o == null || getClass() != o.getClass()) { - return false; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final FunctionOutputTemplate template = (FunctionOutputTemplate) o; + return Objects.equals(dataType, template.dataType); + } + + @Override + public int hashCode() { + return Objects.hash(dataType); } - FunctionResultTemplate that = (FunctionResultTemplate) o; - return dataType.equals(that.dataType); } - @Override - public int hashCode() { - return Objects.hash(dataType); + class FunctionStateTemplate implements FunctionResultTemplate { + + private final LinkedHashMap state; + + private FunctionStateTemplate(LinkedHashMap state) { + this.state = state; + } + + List> toClassList() { + return 
state.values().stream() + .map(DataType::getConversionClass) + .collect(Collectors.toList()); + } + + LinkedHashMap toStateTypeStrategies() { + return state.entrySet().stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + e -> createStateTypeStrategy(e.getValue()), + (o, n) -> o, + LinkedHashMap::new)); + } + + String toAccumulatorStateName() { + checkSingleStateEntry(); + return state.keySet().iterator().next(); + } + + TypeStrategy toAccumulatorTypeStrategy() { + checkSingleStateEntry(); + return createTypeStrategy(state.values().iterator().next()); + } + + private void checkSingleStateEntry() { + if (state.size() != 1) { + throw extractionError("Aggregating functions support only one state entry."); + } + } + + private static StateTypeStrategy createStateTypeStrategy(DataType dataType) { + return StateTypeStrategyWrapper.of(TypeStrategies.explicit(dataType)); + } + + private static TypeStrategy createTypeStrategy(DataType dataType) { + return TypeStrategies.explicit(dataType); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final FunctionStateTemplate that = (FunctionStateTemplate) o; + return Objects.equals(state, that.state); + } + + @Override + public int hashCode() { + return Objects.hash(state); + } } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java index 855f11e6fa830d..db641fb39ed270 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionSignatureTemplate.java @@ -23,11 +23,14 @@ import org.apache.flink.table.types.inference.ArgumentTypeStrategy; import org.apache.flink.table.types.inference.InputTypeStrategies; import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.StaticArgument; +import org.apache.flink.table.types.inference.StaticArgumentTrait; import javax.annotation.Nullable; import java.lang.reflect.Array; import java.util.Arrays; +import java.util.EnumSet; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -43,17 +46,21 @@ final class FunctionSignatureTemplate { final boolean isVarArgs; + final EnumSet[] argumentTraits; + final @Nullable String[] argumentNames; - final Boolean[] argumentOptionals; + final boolean[] argumentOptionals; private FunctionSignatureTemplate( List argumentTemplates, boolean isVarArgs, + EnumSet[] argumentTraits, @Nullable String[] argumentNames, - Boolean[] argumentOptionals) { + boolean[] argumentOptionals) { this.argumentTemplates = argumentTemplates; this.isVarArgs = isVarArgs; + this.argumentTraits = argumentTraits; this.argumentNames = argumentNames; this.argumentOptionals = argumentOptionals; } @@ -61,8 +68,9 @@ private FunctionSignatureTemplate( static FunctionSignatureTemplate of( List argumentTemplates, boolean isVarArgs, + EnumSet[] argumentTraits, @Nullable String[] argumentNames, - Boolean[] argumentOptionals) { + boolean[] argumentOptionals) { if (argumentNames != null && argumentNames.length != argumentTemplates.size()) { throw extractionError( "Mismatch between number of argument names '%s' and argument types '%s'.", @@ -80,7 +88,7 @@ static 
FunctionSignatureTemplate of( } if (argumentOptionals != null) { for (int i = 0; i < argumentTemplates.size(); i++) { - DataType dataType = argumentTemplates.get(i).dataType; + DataType dataType = argumentTemplates.get(i).toDataType(); if (dataType != null && !dataType.getLogicalType().isNullable() && argumentOptionals[i]) { @@ -91,7 +99,70 @@ static FunctionSignatureTemplate of( } } return new FunctionSignatureTemplate( - argumentTemplates, isVarArgs, argumentNames, argumentOptionals); + argumentTemplates, isVarArgs, argumentTraits, argumentNames, argumentOptionals); + } + + /** + * Converts the given signature into a list of static arguments if the signature allows it. E.g. + * no var-args and all arguments are named. + */ + @Nullable + List toStaticArguments() { + if (isVarArgs || argumentNames == null) { + return null; + } + final List arguments = + IntStream.range(0, argumentTemplates.size()) + .mapToObj( + pos -> { + final String name = argumentNames[pos]; + final boolean isOptional = argumentOptionals[pos]; + final FunctionArgumentTemplate template = + argumentTemplates.get(pos); + final EnumSet traits = argumentTraits[pos]; + if (traits.contains(StaticArgumentTrait.TABLE_AS_ROW) + || traits.contains(StaticArgumentTrait.TABLE_AS_SET)) { + return createTableArgument( + name, + isOptional, + traits, + template.toDataType(), + template.toConversionClass()); + } else if (traits.contains(StaticArgumentTrait.SCALAR)) { + return createScalarArgument( + name, isOptional, template.toDataType()); + } else { + return null; + } + }) + .collect(Collectors.toList()); + if (arguments.contains(null)) { + return null; + } + return arguments; + } + + private static @Nullable StaticArgument createTableArgument( + String name, + boolean isOptional, + EnumSet traits, + @Nullable DataType dataType, + @Nullable Class conversionClass) { + if (dataType != null) { + return StaticArgument.table(name, dataType, isOptional, traits); + } + if (conversionClass != null) { + return StaticArgument.table(name, conversionClass, isOptional, traits); + } + return null; + } + + private static @Nullable StaticArgument createScalarArgument( + String name, boolean isOptional, @Nullable DataType dataType) { + if (dataType != null) { + return StaticArgument.scalar(name, dataType, isOptional); + } + return null; } InputTypeStrategy toInputTypeStrategy() { @@ -117,7 +188,7 @@ InputTypeStrategy toInputTypeStrategy() { return strategy; } - List> toClass() { + List> toClassList() { return IntStream.range(0, argumentTemplates.size()) .mapToObj( i -> { diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java index c92830393cc148..0b62b0e340cda8 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionTemplate.java @@ -24,15 +24,27 @@ import org.apache.flink.table.annotation.DataTypeHint; import org.apache.flink.table.annotation.FunctionHint; import org.apache.flink.table.annotation.ProcedureHint; +import org.apache.flink.table.annotation.StateHint; import org.apache.flink.table.catalog.DataTypeFactory; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate; +import 
org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate; +import org.apache.flink.table.types.inference.StaticArgumentTrait; +import org.apache.flink.types.Row; import javax.annotation.Nullable; import java.lang.annotation.Annotation; import java.util.Arrays; +import java.util.EnumSet; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.IntStream; import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError; @@ -47,16 +59,16 @@ final class FunctionTemplate { private final @Nullable FunctionSignatureTemplate signatureTemplate; - private final @Nullable FunctionResultTemplate accumulatorTemplate; + private final @Nullable FunctionStateTemplate stateTemplate; - private final @Nullable FunctionResultTemplate outputTemplate; + private final @Nullable FunctionOutputTemplate outputTemplate; private FunctionTemplate( @Nullable FunctionSignatureTemplate signatureTemplate, - @Nullable FunctionResultTemplate accumulatorTemplate, - @Nullable FunctionResultTemplate outputTemplate) { + @Nullable FunctionStateTemplate stateTemplate, + @Nullable FunctionOutputTemplate outputTemplate) { this.signatureTemplate = signatureTemplate; - this.accumulatorTemplate = accumulatorTemplate; + this.stateTemplate = stateTemplate; this.outputTemplate = outputTemplate; } @@ -64,10 +76,8 @@ private FunctionTemplate( * Creates an instance using the given {@link FunctionHint}. It resolves explicitly defined data * types. */ + @SuppressWarnings("deprecation") static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, FunctionHint hint) { - if (hint.state().length > 0) { - throw extractionError("State hints are not supported yet."); - } return new FunctionTemplate( createSignatureTemplate( typeFactory, @@ -76,14 +86,18 @@ static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, FunctionHint defaultAsNull(hint, FunctionHint::argument), defaultAsNull(hint, FunctionHint::arguments), hint.isVarArgs()), - createResultTemplate(typeFactory, defaultAsNull(hint, FunctionHint::accumulator)), - createResultTemplate(typeFactory, defaultAsNull(hint, FunctionHint::output))); + createStateTemplate( + typeFactory, + defaultAsNull(hint, FunctionHint::accumulator), + defaultAsNull(hint, FunctionHint::state)), + createOutputTemplate(typeFactory, defaultAsNull(hint, FunctionHint::output))); } /** * Creates an instance using the given {@link ProcedureHint}. It resolves explicitly defined * data types. */ + @SuppressWarnings("deprecation") static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, ProcedureHint hint) { return new FunctionTemplate( createSignatureTemplate( @@ -93,12 +107,12 @@ static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, ProcedureHin defaultAsNull(hint, ProcedureHint::argument), defaultAsNull(hint, ProcedureHint::arguments), hint.isVarArgs()), - null, - createResultTemplate(typeFactory, defaultAsNull(hint, ProcedureHint::output))); + createStateTemplate(typeFactory, null, null), + createOutputTemplate(typeFactory, defaultAsNull(hint, ProcedureHint::output))); } /** Creates an instance of {@link FunctionResultTemplate} from a {@link DataTypeHint}. 
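For orientation, an output template like the one created here typically originates from hints such as the following. A minimal sketch assuming standard @FunctionHint/@DataTypeHint usage; the class name is illustrative:

import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.types.Row;

// The output hint is resolved into a FunctionOutputTemplate, i.e. an explicit
// type strategy for ROW<i INT, s STRING>; without it, extraction would fall
// back to the reflective return type, which is not enough for a raw Row.
@FunctionHint(output = @DataTypeHint("ROW<i INT, s STRING>"))
public class RowEmittingFunction extends ScalarFunction {
    public Row eval(int i, String s) {
        return Row.of(i, s);
    }
}
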
*/ - static @Nullable FunctionResultTemplate createResultTemplate( + static @Nullable FunctionOutputTemplate createOutputTemplate( DataTypeFactory typeFactory, @Nullable DataTypeHint hint) { if (hint == null) { return null; @@ -110,20 +124,49 @@ static FunctionTemplate fromAnnotation(DataTypeFactory typeFactory, ProcedureHin throw extractionError(t, "Error in data type hint annotation."); } if (template.dataType != null) { - return FunctionResultTemplate.of(template.dataType); + return FunctionResultTemplate.ofOutput(template.dataType); } throw extractionError( "Data type hint does not specify a data type for use as function result."); } + /** Creates a {@link FunctionStateTemplate}s from {@link StateHint}s or accumulator. */ + static @Nullable FunctionStateTemplate createStateTemplate( + DataTypeFactory typeFactory, + @Nullable DataTypeHint accumulatorHint, + @Nullable StateHint[] stateHints) { + if (accumulatorHint == null && stateHints == null) { + return null; + } + if (accumulatorHint != null && stateHints != null) { + throw extractionError( + "State hints and accumulator cannot be declared in the same function hint. " + + "Use either one or the other."); + } + final LinkedHashMap state = new LinkedHashMap<>(); + if (accumulatorHint != null) { + state.put("acc", createStateDataType(typeFactory, accumulatorHint, "accumulator")); + return FunctionResultTemplate.ofState(state); + } + IntStream.range(0, stateHints.length) + .forEach( + pos -> { + final StateHint hint = stateHints[pos]; + state.put( + hint.name(), + createStateDataType(typeFactory, hint.type(), "state entry")); + }); + return FunctionResultTemplate.ofState(state); + } + @Nullable FunctionSignatureTemplate getSignatureTemplate() { return signatureTemplate; } @Nullable - FunctionResultTemplate getAccumulatorTemplate() { - return accumulatorTemplate; + FunctionResultTemplate getStateTemplate() { + return stateTemplate; } @Nullable @@ -141,13 +184,13 @@ public boolean equals(Object o) { } FunctionTemplate template = (FunctionTemplate) o; return Objects.equals(signatureTemplate, template.signatureTemplate) - && Objects.equals(accumulatorTemplate, template.accumulatorTemplate) + && Objects.equals(stateTemplate, template.stateTemplate) && Objects.equals(outputTemplate, template.outputTemplate); } @Override public int hashCode() { - return Objects.hash(signatureTemplate, accumulatorTemplate, outputTemplate); + return Objects.hash(signatureTemplate, stateTemplate, outputTemplate); } // -------------------------------------------------------------------------------------------- @@ -185,6 +228,7 @@ private static T defaultAsNull( return actualValue; } + @SuppressWarnings("unchecked") private static @Nullable FunctionSignatureTemplate createSignatureTemplate( DataTypeFactory typeFactory, @Nullable DataTypeHint[] inputs, @@ -204,79 +248,150 @@ private static T defaultAsNull( argumentHints = pluralArgumentHints; } - String[] argumentHintNames; - DataTypeHint[] argumentHintTypes; - // Deal with #arguments() and #input() if (argumentHints != null && inputs != null) { throw extractionError( - "Argument and input hints cannot be declared in the same function hint."); + "Argument and input hints cannot be declared in the same function hint. 
" + + "Use either one or the other."); } - - Boolean[] argumentOptionals; + final DataTypeHint[] argumentHintTypes; + final boolean[] argumentOptionals; + final ArgumentTrait[][] argumentTraits; + String[] argumentHintNames; if (argumentHints != null) { - final boolean allScalar = - Arrays.stream(argumentHints) - .allMatch( - h -> { - final ArgumentTrait[] traits = h.value(); - return traits.length == 1 - && traits[0] == ArgumentTrait.SCALAR; - }); - if (!allScalar) { - throw extractionError("Only scalar arguments are supported so far."); - } - - argumentHintNames = new String[argumentHints.length]; argumentHintTypes = new DataTypeHint[argumentHints.length]; - argumentOptionals = new Boolean[argumentHints.length]; - boolean allArgumentNameNotSet = true; + argumentOptionals = new boolean[argumentHints.length]; + argumentTraits = new ArgumentTrait[argumentHints.length][]; + argumentHintNames = new String[argumentHints.length]; + boolean allArgumentNamesNotSet = true; for (int i = 0; i < argumentHints.length; i++) { - ArgumentHint argumentHint = argumentHints[i]; + final ArgumentHint argumentHint = argumentHints[i]; argumentHintNames[i] = defaultAsNull(argumentHint, ArgumentHint::name); argumentHintTypes[i] = defaultAsNull(argumentHint, ArgumentHint::type); argumentOptionals[i] = argumentHint.isOptional(); - if (argumentHintTypes[i] == null) { - throw extractionError("The type of the argument at position %d is not set.", i); - } + argumentTraits[i] = argumentHint.value(); if (argumentHintNames[i] != null) { - allArgumentNameNotSet = false; - } else if (!allArgumentNameNotSet) { + allArgumentNamesNotSet = false; + } else if (!allArgumentNamesNotSet) { throw extractionError( - "The argument name in function hint must be either fully set or not set at all."); + "Argument names in function hint must be either fully set or not set at all."); } } - if (allArgumentNameNotSet) { + if (allArgumentNamesNotSet) { argumentHintNames = null; } - } else { - if (inputs == null) { - return null; - } + } else if (inputs != null) { argumentHintTypes = inputs; argumentHintNames = argumentNames; - argumentOptionals = new Boolean[inputs.length]; - Arrays.fill(argumentOptionals, false); + argumentOptionals = new boolean[inputs.length]; + argumentTraits = new ArgumentTrait[inputs.length][]; + Arrays.fill(argumentTraits, new ArgumentTrait[] {ArgumentTrait.SCALAR}); + } else { + return null; } + final List argumentTemplates = + IntStream.range(0, argumentHintTypes.length) + .mapToObj( + i -> + createArgumentTemplate( + typeFactory, + i, + argumentHintTypes[i], + argumentTraits[i])) + .collect(Collectors.toList()); + return FunctionSignatureTemplate.of( - Arrays.stream(argumentHintTypes) - .map(dataTypeHint -> createArgumentTemplate(typeFactory, dataTypeHint)) - .collect(Collectors.toList()), + argumentTemplates, isVarArg, + Arrays.stream(argumentTraits) + .map( + t -> { + final List traits = + Arrays.stream(t) + .map(ArgumentTrait::toStaticTrait) + .collect(Collectors.toList()); + return EnumSet.copyOf(traits); + }) + .toArray(EnumSet[]::new), argumentHintNames, argumentOptionals); } private static FunctionArgumentTemplate createArgumentTemplate( - DataTypeFactory typeFactory, DataTypeHint hint) { - final DataTypeTemplate template = DataTypeTemplate.fromAnnotation(typeFactory, hint); + DataTypeFactory typeFactory, + int pos, + @Nullable DataTypeHint hint, + ArgumentTrait[] argumentTraits) { + final Set rootTrait = + Arrays.stream(argumentTraits) + .filter(ArgumentTrait::isRoot) + .collect(Collectors.toSet()); + if 
(rootTrait.size() != 1) { + throw extractionError( + "Incorrect argument kind at position %d. Argument kind must be one of: %s", + pos, + Arrays.stream(ArgumentTrait.values()) + .filter(ArgumentTrait::isRoot) + .collect(Collectors.toList())); + } + + if (rootTrait.contains(ArgumentTrait.SCALAR)) { + if (hint != null) { + final DataTypeTemplate template; + try { + template = DataTypeTemplate.fromAnnotation(typeFactory, hint); + } catch (Throwable t) { + throw extractionError( + t, + "Error in data type hint annotation for argument at position %s.", + pos); + } + if (template.dataType != null) { + return FunctionArgumentTemplate.ofDataType(template.dataType); + } else if (template.inputGroup != null) { + return FunctionArgumentTemplate.ofInputGroup(template.inputGroup); + } + } + throw extractionError("Data type missing for scalar argument at position %s.", pos); + } else if (rootTrait.contains(ArgumentTrait.TABLE_AS_ROW) + || rootTrait.contains(ArgumentTrait.TABLE_AS_SET)) { + try { + final DataTypeTemplate template = + DataTypeTemplate.fromAnnotation(typeFactory, hint); + if (template.dataType != null) { + return FunctionArgumentTemplate.ofDataType(template.dataType); + } else if (template.inputGroup != null) { + throw extractionError( + "Input groups are not supported for table argument at position %s.", + pos); + } + return FunctionArgumentTemplate.ofTable(Row.class); + } catch (Throwable t) { + final Class argClass = hint == null ? Row.class : hint.bridgedTo(); + if (argClass == Row.class || argClass == RowData.class) { + return FunctionArgumentTemplate.ofTable(argClass); + } + // Just a regular error for a typed argument + throw t; + } + } else { + throw extractionError("Unknown argument kind."); + } + } + + private static DataType createStateDataType( + DataTypeFactory typeFactory, DataTypeHint dataTypeHint, String description) { + final DataTypeTemplate template; + try { + template = DataTypeTemplate.fromAnnotation(typeFactory, dataTypeHint); + } catch (Throwable t) { + throw extractionError(t, "Error in data type hint annotation."); + } if (template.dataType != null) { - return FunctionArgumentTemplate.of(template.dataType); - } else if (template.inputGroup != null) { - return FunctionArgumentTemplate.of(template.inputGroup); + return template.dataType; } throw extractionError( - "Data type hint does neither specify a data type nor input group for use as function argument."); + "Data type hint does not specify a data type for use as %s.", description); } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java index da37485011b437..43765a9b5cbcd9 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java @@ -22,9 +22,11 @@ import org.apache.flink.table.catalog.DataTypeFactory; import org.apache.flink.table.procedures.Procedure; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate; import java.lang.reflect.Array; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.stream.Stream; @@ -49,12 +51,66 @@ final class ProcedureMappingExtractor extends 
BaseMappingExtractor { Class procedure, String methodName, SignatureExtraction signatureExtraction, - ResultExtraction outputExtraction, + ResultExtraction resultExtraction, MethodVerification verification) { - super(typeFactory, methodName, signatureExtraction, outputExtraction, verification); + super(typeFactory, methodName, signatureExtraction, resultExtraction, verification); this.procedure = procedure; } + // -------------------------------------------------------------------------------------------- + // Extraction strategy + // -------------------------------------------------------------------------------------------- + + /** + * Extraction that uses the method return type for producing a {@link FunctionOutputTemplate}. + */ + static ResultExtraction createOutputFromArrayReturnTypeInMethod() { + return (extractor, method) -> { + final DataType dataType = + DataTypeExtractor.extractFromMethodReturnType( + extractor.typeFactory, + extractor.getFunctionClass(), + method, + method.getReturnType().getComponentType()); + return FunctionResultTemplate.ofOutput(dataType); + }; + } + + // -------------------------------------------------------------------------------------------- + // Verification strategy + // -------------------------------------------------------------------------------------------- + + /** + * Verification that checks a method by parameters (arguments only) with mandatory context and + * array return type. + */ + static MethodVerification createParameterWithOptionalContextAndArrayReturnTypeVerification() { + return (method, state, arguments, result) -> { + checkNoState(state); + final Class[] parameters = assembleParameters(state, arguments); + // ignore the ProcedureContext in the first argument + final Class[] parametersWithContext = + Stream.concat(Stream.of((Class) null), Arrays.stream(parameters)) + .toArray(Class[]::new); + final Class returnType = method.getReturnType(); + final boolean isValid = + isInvokable(true, method, parametersWithContext) + && returnType.isArray() + && isAssignable(result, returnType.getComponentType(), true, true); + if (!isValid) { + throw createMethodNotFoundError( + method.getName(), + parametersWithContext, + Array.newInstance(result, 0).getClass(), + "( [, ]*)"); + } + }; + } + + // -------------------------------------------------------------------------------------------- + // Methods from super class + // -------------------------------------------------------------------------------------------- + @Override protected Set extractGlobalFunctionTemplates() { return TemplateUtils.extractProcedureGlobalFunctionTemplates(typeFactory, procedure); @@ -79,37 +135,4 @@ protected Class getFunctionClass() { protected String getHintType() { return "Procedure"; } - - /** - * Extraction that uses the method return type for producing a {@link FunctionResultTemplate}. 
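To illustrate the shape that the array-return extraction and verification in this class target, a minimal procedure sketch; it assumes the Procedure/ProcedureContext contract and uses illustrative names:

import org.apache.flink.table.procedures.Procedure;
import org.apache.flink.table.procedures.ProcedureContext;

// The element type of the returned array (String) is what the array-return
// extraction reads; the ProcedureContext in the first position is skipped by
// the verification, which prepends a null placeholder before checking
// invokability.
public class SplitProcedure implements Procedure {
    public String[] call(ProcedureContext context, String input, String delimiter) {
        return input.split(delimiter);
    }
}
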
- */ - static ResultExtraction createReturnTypeResultExtraction() { - return (extractor, method) -> { - final DataType dataType = - DataTypeExtractor.extractFromMethodOutput( - extractor.typeFactory, - extractor.getFunctionClass(), - method, - method.getReturnType().getComponentType()); - return FunctionResultTemplate.of(dataType); - }; - } - - static MethodVerification createParameterAndReturnTypeVerification() { - return ((method, signature, result) -> { - // ignore the ProcedureContext in the first argument - final Class[] parameters = - Stream.concat(Stream.of((Class) null), signature.stream()) - .toArray(Class[]::new); - final Class returnType = method.getReturnType(); - final boolean isValid = - isInvokable(method, parameters) - && returnType.isArray() - && isAssignable(result, returnType.getComponentType(), true); - if (!isValid) { - throw createMethodNotFoundError( - method.getName(), parameters, Array.newInstance(result, 0).getClass()); - } - }); - } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java index 1b54abd5574769..5ef431057fadf2 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/TypeInferenceExtractor.java @@ -25,6 +25,7 @@ import org.apache.flink.table.functions.AggregateFunction; import org.apache.flink.table.functions.AsyncScalarFunction; import org.apache.flink.table.functions.AsyncTableFunction; +import org.apache.flink.table.functions.ProcessTableFunction; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.functions.TableAggregateFunction; import org.apache.flink.table.functions.TableFunction; @@ -32,33 +33,38 @@ import org.apache.flink.table.functions.UserDefinedFunctionHelper; import org.apache.flink.table.procedures.Procedure; import org.apache.flink.table.procedures.ProcedureDefinition; -import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionOutputTemplate; +import org.apache.flink.table.types.extraction.FunctionResultTemplate.FunctionStateTemplate; import org.apache.flink.table.types.inference.InputTypeStrategies; import org.apache.flink.table.types.inference.InputTypeStrategy; +import org.apache.flink.table.types.inference.StateTypeStrategy; +import org.apache.flink.table.types.inference.StateTypeStrategyWrapper; +import org.apache.flink.table.types.inference.StaticArgument; import org.apache.flink.table.types.inference.TypeInference; import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.table.types.inference.TypeStrategy; import javax.annotation.Nullable; -import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import static org.apache.flink.table.types.extraction.BaseMappingExtractor.createArgumentsFromParametersExtraction; +import static org.apache.flink.table.types.extraction.BaseMappingExtractor.createStateFromParametersExtraction; import static org.apache.flink.table.types.extraction.ExtractionUtils.extractionError; -import static 
org.apache.flink.table.types.extraction.FunctionMappingExtractor.createGenericParameterWithArgumentAndReturnTypeVerification; -import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createGenericResultExtraction; -import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createGenericResultExtractionFromMethod; +import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createOutputFromGenericInClass; +import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createOutputFromGenericInMethod; +import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createOutputFromReturnTypeInMethod; +import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterAndCompletableFutureVerification; +import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterAndOptionalContextVerification; import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterAndReturnTypeVerification; -import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterSignatureExtraction; import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterVerification; -import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterWithAccumulatorVerification; -import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createParameterWithArgumentVerification; -import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createReturnTypeResultExtraction; +import static org.apache.flink.table.types.extraction.FunctionMappingExtractor.createStateFromGenericInClassOrParameters; +import static org.apache.flink.table.types.extraction.ProcedureMappingExtractor.createOutputFromArrayReturnTypeInMethod; +import static org.apache.flink.table.types.extraction.ProcedureMappingExtractor.createParameterWithOptionalContextAndArrayReturnTypeVerification; /** * Reflection-based utility for extracting a {@link TypeInference} from a supported subclass of @@ -81,11 +87,12 @@ public static TypeInference forScalarFunction( typeFactory, function, UserDefinedFunctionHelper.SCALAR_EVAL, - createParameterSignatureExtraction(0), + createArgumentsFromParametersExtraction(0), null, - createReturnTypeResultExtraction(), + null, + createOutputFromReturnTypeInMethod(), createParameterAndReturnTypeVerification()); - return extractTypeInference(mappingExtractor); + return extractTypeInference(mappingExtractor, false); } /** Extracts a type inference from a {@link AsyncScalarFunction}. */ @@ -96,12 +103,12 @@ public static TypeInference forAsyncScalarFunction( typeFactory, function, UserDefinedFunctionHelper.ASYNC_SCALAR_EVAL, - createParameterSignatureExtraction(1), + createArgumentsFromParametersExtraction(1), null, - createGenericResultExtractionFromMethod(0, 0, true), - createGenericParameterWithArgumentAndReturnTypeVerification( - function, CompletableFuture.class, 0, 0)); - return extractTypeInference(mappingExtractor); + null, + createOutputFromGenericInMethod(0, 0, true), + createParameterAndCompletableFutureVerification(function)); + return extractTypeInference(mappingExtractor, false); } /** Extracts a type inference from a {@link AggregateFunction}. 
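To make the generic positions concrete for the aggregate case below: in AggregateFunction<T, ACC> the output type sits at position 0 and the accumulator, now treated as a single state entry, at position 1. A minimal sketch with illustrative names:

import org.apache.flink.table.functions.AggregateFunction;

// AggregateFunction<Long, SumAcc>: Long (generic position 0) becomes the
// output template, SumAcc (generic position 1) becomes the accumulator state
// entry; accumulate(...) is the method whose parameters define the arguments.
public class LongSum extends AggregateFunction<Long, LongSum.SumAcc> {

    public static class SumAcc {
        public long sum = 0L;
    }

    @Override
    public SumAcc createAccumulator() {
        return new SumAcc();
    }

    public void accumulate(SumAcc acc, Long value) {
        if (value != null) {
            acc.sum += value;
        }
    }

    @Override
    public Long getValue(SumAcc acc) {
        return acc.sum;
    }
}
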
*/ @@ -112,11 +119,12 @@ public static TypeInference forAggregateFunction( typeFactory, function, UserDefinedFunctionHelper.AGGREGATE_ACCUMULATE, - createParameterSignatureExtraction(1), - createGenericResultExtraction(AggregateFunction.class, 1, false), - createGenericResultExtraction(AggregateFunction.class, 0, true), - createParameterWithAccumulatorVerification()); - return extractTypeInference(mappingExtractor); + createArgumentsFromParametersExtraction(1), + createStateFromGenericInClassOrParameters(AggregateFunction.class, 1), + createParameterVerification(true), + createOutputFromGenericInClass(AggregateFunction.class, 0, true), + null); + return extractTypeInference(mappingExtractor, false); } /** Extracts a type inference from a {@link TableFunction}. */ @@ -127,11 +135,12 @@ public static TypeInference forTableFunction( typeFactory, function, UserDefinedFunctionHelper.TABLE_EVAL, - createParameterSignatureExtraction(0), + createArgumentsFromParametersExtraction(0), null, - createGenericResultExtraction(TableFunction.class, 0, true), - createParameterVerification()); - return extractTypeInference(mappingExtractor); + null, + createOutputFromGenericInClass(TableFunction.class, 0, true), + createParameterVerification(false)); + return extractTypeInference(mappingExtractor, false); } /** Extracts a type inference from a {@link TableAggregateFunction}. */ @@ -142,11 +151,12 @@ public static TypeInference forTableAggregateFunction( typeFactory, function, UserDefinedFunctionHelper.TABLE_AGGREGATE_ACCUMULATE, - createParameterSignatureExtraction(1), - createGenericResultExtraction(TableAggregateFunction.class, 1, false), - createGenericResultExtraction(TableAggregateFunction.class, 0, true), - createParameterWithAccumulatorVerification()); - return extractTypeInference(mappingExtractor); + createArgumentsFromParametersExtraction(1), + createStateFromGenericInClassOrParameters(TableAggregateFunction.class, 1), + createParameterVerification(true), + createOutputFromGenericInClass(TableAggregateFunction.class, 0, true), + null); + return extractTypeInference(mappingExtractor, false); } /** Extracts a type inference from a {@link AsyncTableFunction}. */ @@ -157,11 +167,30 @@ public static TypeInference forAsyncTableFunction( typeFactory, function, UserDefinedFunctionHelper.ASYNC_TABLE_EVAL, - createParameterSignatureExtraction(1), + createArgumentsFromParametersExtraction(1), null, - createGenericResultExtraction(AsyncTableFunction.class, 0, true), - createParameterWithArgumentVerification(CompletableFuture.class)); - return extractTypeInference(mappingExtractor); + null, + createOutputFromGenericInClass(AsyncTableFunction.class, 0, true), + createParameterAndCompletableFutureVerification(function)); + return extractTypeInference(mappingExtractor, false); + } + + /** Extracts a type inference from a {@link ProcessTableFunction}. 
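The process table function path below is the new surface of this refactoring. A rough sketch of the eval shape it implies, assuming parameter-level @ArgumentHint/@StateHint annotations and an optional Context first parameter as checked by the verification; names, annotation placement, and hint defaults are assumptions rather than confirmed API:

import org.apache.flink.table.annotation.ArgumentHint;
import org.apache.flink.table.annotation.ArgumentTrait;
import org.apache.flink.table.annotation.StateHint;
import org.apache.flink.table.functions.ProcessTableFunction;
import org.apache.flink.types.Row;

// Output type String comes from the ProcessTableFunction generic at position 0.
// Parameter order mirrors the verification: optional Context first, then state
// entries, then declared arguments. Annotating parameters with @StateHint and
// @ArgumentHint is an assumption of this sketch.
public class CountingFunction extends ProcessTableFunction<String> {

    public static class CountState {
        public long count = 0L;
    }

    public void eval(
            ProcessTableFunction.Context ctx,
            @StateHint CountState state,
            @ArgumentHint(ArgumentTrait.TABLE_AS_SET) Row input) {
        state.count++;
        // Emission of output rows is omitted in this sketch.
    }
}
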
*/ + public static TypeInference forProcessTableFunction( + DataTypeFactory typeFactory, Class> function) { + final FunctionMappingExtractor mappingExtractor = + new FunctionMappingExtractor( + typeFactory, + function, + UserDefinedFunctionHelper.PROCESS_TABLE_EVAL, + createArgumentsFromParametersExtraction( + 0, ProcessTableFunction.Context.class), + createStateFromParametersExtraction(), + createParameterAndOptionalContextVerification( + ProcessTableFunction.Context.class, true), + createOutputFromGenericInClass(ProcessTableFunction.class, 0, true), + null); + return extractTypeInference(mappingExtractor, true); } /** Extracts a type in inference from a {@link Procedure}. */ @@ -172,15 +201,16 @@ public static TypeInference forProcedure( typeFactory, procedure, ProcedureDefinition.PROCEDURE_CALL, - ProcedureMappingExtractor.createParameterSignatureExtraction(1), - ProcedureMappingExtractor.createReturnTypeResultExtraction(), - ProcedureMappingExtractor.createParameterAndReturnTypeVerification()); + createArgumentsFromParametersExtraction(1), + createOutputFromArrayReturnTypeInMethod(), + createParameterWithOptionalContextAndArrayReturnTypeVerification()); return extractTypeInference(mappingExtractor); } - private static TypeInference extractTypeInference(FunctionMappingExtractor mappingExtractor) { + private static TypeInference extractTypeInference( + FunctionMappingExtractor mappingExtractor, boolean requiresStaticSignature) { try { - return extractTypeInferenceOrError(mappingExtractor); + return extractTypeInferenceOrError(mappingExtractor, requiresStaticSignature); } catch (Throwable t) { throw extractionError( t, @@ -192,7 +222,9 @@ private static TypeInference extractTypeInference(FunctionMappingExtractor mappi private static TypeInference extractTypeInference(ProcedureMappingExtractor mappingExtractor) { try { - return extractTypeInferenceOrError(mappingExtractor); + final Map outputMapping = + mappingExtractor.extractOutputMapping(); + return buildInference(null, outputMapping, false); } catch (Throwable t) { throw extractionError( t, @@ -203,110 +235,118 @@ private static TypeInference extractTypeInference(ProcedureMappingExtractor mapp } private static TypeInference extractTypeInferenceOrError( - FunctionMappingExtractor mappingExtractor) { - final Map outputMapping = + FunctionMappingExtractor mappingExtractor, boolean requiresStaticSignature) { + final Map outputMapping = mappingExtractor.extractOutputMapping(); - if (!mappingExtractor.hasAccumulator()) { - return buildInference(null, outputMapping); + if (!mappingExtractor.supportsState()) { + return buildInference(null, outputMapping, requiresStaticSignature); } - final Map accumulatorMapping = - mappingExtractor.extractAccumulatorMapping(); - return buildInference(accumulatorMapping, outputMapping); - } + final Map stateMapping = + mappingExtractor.extractStateMapping(); - private static TypeInference extractTypeInferenceOrError( - ProcedureMappingExtractor mappingExtractor) { - final Map outputMapping = - mappingExtractor.extractOutputMapping(); - return buildInference(null, outputMapping); + return buildInference(stateMapping, outputMapping, requiresStaticSignature); } private static TypeInference buildInference( - @Nullable Map accumulatorMapping, - Map outputMapping) { + @Nullable Map stateMapping, + Map outputMapping, + boolean requiresStaticSignature) { final TypeInference.Builder builder = TypeInference.newBuilder(); - configureNamedArguments(builder, outputMapping); - configureOptionalArguments(builder, 
outputMapping); - configureTypedArguments(builder, outputMapping); - - builder.inputTypeStrategy(translateInputTypeStrategy(outputMapping)); + if (!configureStaticArguments(builder, outputMapping)) { + if (requiresStaticSignature) { + throw extractionError( + "Process table functions require a non-overloaded, non-vararg, and static signature."); + } + builder.inputTypeStrategy(translateInputTypeStrategy(outputMapping)); + } - if (accumulatorMapping != null) { - // verify that accumulator and output are derived from the same input strategy - if (!accumulatorMapping.keySet().equals(outputMapping.keySet())) { + if (stateMapping != null) { + // verify that state and output are derived from the same signatures + if (!stateMapping.keySet().equals(outputMapping.keySet())) { throw extractionError( - "Mismatch between accumulator signature and output signature. " + "Mismatch between state signature and output signature. " + "Both intermediate and output results must be derived from the same input strategy."); } - builder.accumulatorTypeStrategy(translateResultTypeStrategy(accumulatorMapping)); + builder.stateTypeStrategies(translateStateTypeStrategies(stateMapping)); } - builder.outputTypeStrategy(translateResultTypeStrategy(outputMapping)); + builder.outputTypeStrategy(translateOutputTypeStrategy(outputMapping)); + return builder.build(); } - private static void configureNamedArguments( + private static boolean configureStaticArguments( TypeInference.Builder builder, - Map outputMapping) { + Map outputMapping) { final Set signatures = outputMapping.keySet(); - if (signatures.stream().anyMatch(s -> s.isVarArgs || s.argumentNames == null)) { - return; + if (signatures.size() != 1) { + // Function is overloaded + return false; } - final List> argumentNames = - signatures.stream() - .map( - s -> { - assert s.argumentNames != null; - return Arrays.asList(s.argumentNames); - }) - .collect(Collectors.toList()); - if (argumentNames.size() != 1) { - return; + final List arguments = signatures.iterator().next().toStaticArguments(); + if (arguments == null) { + // Function is var arg or non-static (e.g. 
uses input groups instead of typed arguments) + return false; } - builder.namedArguments(argumentNames.iterator().next()); + builder.staticArguments(arguments); + return true; } - private static void configureOptionalArguments( - TypeInference.Builder builder, - Map outputMapping) { - final Set signatures = outputMapping.keySet(); - if (signatures.stream().anyMatch(s -> s.isVarArgs || s.argumentNames == null)) { - return; - } - final List> argumentOptional = - signatures.stream() - .filter(s -> s.argumentOptionals != null) - .map(s -> Arrays.asList(s.argumentOptionals)) - .collect(Collectors.toList()); - if (argumentOptional.size() != 1 || argumentOptional.size() != signatures.size()) { - return; - } - builder.optionalArguments(argumentOptional.get(0)); + private static InputTypeStrategy translateInputTypeStrategy( + Map outputMapping) { + return outputMapping.keySet().stream() + .map(FunctionSignatureTemplate::toInputTypeStrategy) + .reduce(InputTypeStrategies::or) + .orElse(InputTypeStrategies.sequence()); } - private static void configureTypedArguments( - TypeInference.Builder builder, - Map outputMapping) { - if (outputMapping.size() != 1) { - return; + private static LinkedHashMap translateStateTypeStrategies( + Map stateMapping) { + // Simple signatures don't require a mapping, default for process table functions + if (stateMapping.size() == 1) { + final FunctionStateTemplate template = + stateMapping.entrySet().iterator().next().getValue(); + return template.toStateTypeStrategies(); } - final FunctionSignatureTemplate signature = outputMapping.keySet().iterator().next(); - final List dataTypes = - signature.argumentTemplates.stream() - .map(a -> a.dataType) - .collect(Collectors.toList()); - if (!signature.isVarArgs && dataTypes.stream().allMatch(Objects::nonNull)) { - builder.typedArguments(dataTypes); + // For overloaded signatures to accumulators in aggregating functions + final Map mappings = + stateMapping.entrySet().stream() + .collect( + Collectors.toMap( + e -> e.getKey().toInputTypeStrategy(), + e -> e.getValue().toAccumulatorTypeStrategy())); + final StateTypeStrategy accumulatorStrategy = + StateTypeStrategyWrapper.of(TypeStrategies.mapping(mappings)); + final Set stateNames = + stateMapping.values().stream() + .map(FunctionStateTemplate::toAccumulatorStateName) + .collect(Collectors.toSet()); + if (stateMapping.size() > 1 && stateNames.size() > 1) { + throw extractionError( + "Overloaded aggregating functions must use the same name for state entries. 
" + + "Found: %s", + stateNames); } + final String stateName = stateNames.iterator().next(); + final LinkedHashMap stateTypeStrategies = new LinkedHashMap<>(); + stateTypeStrategies.put(stateName, accumulatorStrategy); + return stateTypeStrategies; } - private static TypeStrategy translateResultTypeStrategy( - Map resultMapping) { + private static TypeStrategy translateOutputTypeStrategy( + Map outputMapping) { + // Simple signatures don't require a mapping + if (outputMapping.size() == 1) { + final FunctionOutputTemplate template = + outputMapping.entrySet().iterator().next().getValue(); + return template.toTypeStrategy(); + } + // For overloaded signatures final Map mappings = - resultMapping.entrySet().stream() + outputMapping.entrySet().stream() .collect( Collectors.toMap( e -> e.getKey().toInputTypeStrategy(), @@ -314,12 +354,4 @@ private static TypeStrategy translateResultTypeStrategy( (t1, t2) -> t2)); return TypeStrategies.mapping(mappings); } - - private static InputTypeStrategy translateInputTypeStrategy( - Map outputMapping) { - return outputMapping.keySet().stream() - .map(FunctionSignatureTemplate::toInputTypeStrategy) - .reduce(InputTypeStrategies::or) - .orElse(InputTypeStrategies.sequence()); - } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java index ad791585c7806c..1faf3a0deaa55d 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgument.java @@ -19,15 +19,21 @@ package org.apache.flink.table.types.inference; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.NullType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.StructuredType; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; import org.apache.flink.util.Preconditions; import javax.annotation.Nullable; import java.util.EnumSet; +import java.util.Objects; import java.util.Optional; +import java.util.stream.Collectors; /** * Describes an argument in a static signature that is not overloaded and does not support varargs. 
@@ -43,6 +49,8 @@
 @PublicEvolving
 public class StaticArgument {
 
+    private static final RowType DUMMY_ROW_TYPE = RowType.of(new NullType());
+
     private final String name;
     private final @Nullable DataType dataType;
     private final @Nullable Class<?> conversionClass;
@@ -55,13 +63,15 @@ private StaticArgument(
             @Nullable Class<?> conversionClass,
             boolean isOptional,
             EnumSet<StaticArgumentTrait> traits) {
-        StaticArgumentTrait.checkIntegrity(
-                Preconditions.checkNotNull(traits, "Traits must not be null."));
         this.name = Preconditions.checkNotNull(name, "Name must not be null.");
         this.dataType = dataType;
         this.conversionClass = conversionClass;
         this.isOptional = isOptional;
-        this.traits = traits;
+        this.traits = Preconditions.checkNotNull(traits, "Traits must not be null.");
+        checkName();
+        checkTraits(traits);
+        checkOptionalType();
+        checkTableType();
     }
 
     /**
@@ -162,4 +172,121 @@ public boolean isOptional() {
     public EnumSet<StaticArgumentTrait> getTraits() {
         return traits;
     }
+
+    @Override
+    public String toString() {
+        final StringBuilder s = new StringBuilder();
+        // Possible signatures:
+        // (myScalar INT)
+        // (myTypedTable ROW {TABLE BY ROW})
+        // (myUntypedTable {TABLE BY ROW})
+        s.append(name);
+        if (dataType != null) {
+            s.append(" ");
+            s.append(dataType);
+        }
+        if (!traits.equals(EnumSet.of(StaticArgumentTrait.SCALAR))) {
+            s.append(" ");
+            s.append(traits.stream().map(Enum::name).collect(Collectors.joining(", ", "{", "}")));
+        }
+        return s.toString();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        final StaticArgument that = (StaticArgument) o;
+        return isOptional == that.isOptional
+                && Objects.equals(name, that.name)
+                && Objects.equals(dataType, that.dataType)
+                && Objects.equals(conversionClass, that.conversionClass)
+                && Objects.equals(traits, that.traits);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(name, dataType, conversionClass, isOptional, traits);
+    }
+
+    private void checkName() {
+        if (!TypeInference.PARAMETER_NAME_FORMAT.test(name)) {
+            throw new ValidationException(
+                    String.format(
+                            "Invalid argument name '%s'. An argument must follow "
+                                    + "the pattern [a-zA-Z_$][a-zA-Z_$0-9].",
+                            name));
+        }
+    }
+
+    private void checkTraits(EnumSet<StaticArgumentTrait> traits) {
+        if (traits.stream().filter(t -> t.getRequirements().isEmpty()).count() != 1) {
+            throw new ValidationException(
+                    String.format(
+                            "Invalid argument traits for argument '%s'. "
+                                    + "An argument must be declared as either scalar, table, or model.",
+                            name));
+        }
+        traits.forEach(
+                trait ->
+                        trait.getRequirements()
+                                .forEach(
+                                        requirement -> {
+                                            if (!traits.contains(requirement)) {
+                                                throw new ValidationException(
+                                                        String.format(
+                                                                "Invalid argument traits for argument '%s'. Trait %s requires %s.",
+                                                                name, trait, requirement));
+                                            }
+                                        }));
+    }
+
+    private void checkOptionalType() {
+        if (!isOptional) {
+            return;
+        }
+        // e.g. for untyped table arguments
+        if (dataType == null) {
+            return;
+        }
+
+        final LogicalType type = dataType.getLogicalType();
+        if (!type.isNullable() || !type.supportsInputConversion(dataType.getConversionClass())) {
+            throw new ValidationException(
+                    String.format(
+                            "Invalid data type for optional argument '%s'. "
+                                    + "An optional argument has to accept null values.",
+                            name));
+        }
+    }
+
+    void checkTableType() {
+        if (!traits.contains(StaticArgumentTrait.TABLE)) {
+            return;
+        }
+        if (dataType == null
+                && conversionClass != null
+                && !DUMMY_ROW_TYPE.supportsInputConversion(conversionClass)) {
+            throw new ValidationException(
+                    String.format(
+                            "Invalid conversion class '%s' for argument '%s'. "
+                                    + "Polymorphic, untyped table arguments must use a row class.",
+                            conversionClass.getName(), name));
+        }
+        if (dataType != null) {
+            final LogicalType type = dataType.getLogicalType();
+            if (traits.contains(StaticArgumentTrait.TABLE)
+                    && !LogicalTypeChecks.isCompositeType(type)) {
+                throw new ValidationException(
+                        String.format(
+                                "Invalid data type '%s' for table argument '%s'. "
+                                        + "Typed table arguments must use a composite type (i.e. row or structured type).",
+                                type, name));
+            }
+        }
+    }
 }
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
index 76a4e6e26902a6..0590d21a340cd3 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/StaticArgumentTrait.java
@@ -19,10 +19,8 @@
 package org.apache.flink.table.types.inference;
 
 import org.apache.flink.annotation.PublicEvolving;
-import org.apache.flink.table.api.ValidationException;
 
 import java.util.Arrays;
-import java.util.EnumSet;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -47,21 +45,7 @@ public enum StaticArgumentTrait {
         this.requirements = Arrays.stream(requirements).collect(Collectors.toSet());
     }
 
-    public static void checkIntegrity(EnumSet<StaticArgumentTrait> traits) {
-        if (traits.stream().filter(t -> t.requirements.isEmpty()).count() != 1) {
-            throw new ValidationException(
-                    "Invalid argument traits. An argument must be declared as either scalar, table, or model.");
-        }
-        traits.forEach(
-                trait ->
-                        trait.requirements.forEach(
-                                requirement -> {
-                                    if (!traits.contains(requirement)) {
-                                        throw new ValidationException(
-                                                String.format(
-                                                        "Invalid argument traits. Trait %s requires %s.",
-                                                        trait, requirement));
-                                    }
-                                }));
-    }
+    public Set<StaticArgumentTrait> getRequirements() {
+        return requirements;
+    }
 }
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java
index 64d372b36c1e78..1939c34b6f18dc 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInference.java
@@ -19,6 +19,7 @@
 package org.apache.flink.table.types.inference;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.table.api.ValidationException;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.util.Preconditions;
 
@@ -28,6 +29,8 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Optional;
+import java.util.function.Predicate;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -47,6 +50,10 @@
 @PublicEvolving
 public final class TypeInference {
 
+    /** Format for both arguments and state entries. */
+    static final Predicate<String> PARAMETER_NAME_FORMAT =
+            Pattern.compile("^[a-zA-Z_$][a-zA-Z_$0-9]*$").asPredicate();
+
     private final @Nullable List<StaticArgument> staticArguments;
     private final InputTypeStrategy inputTypeStrategy;
     private final LinkedHashMap stateTypeStrategies;
@@ -61,6 +68,7 @@ private TypeInference(
         this.inputTypeStrategy = inputTypeStrategy;
         this.stateTypeStrategies = stateTypeStrategies;
         this.outputTypeStrategy = outputTypeStrategy;
+        checkStateEntries();
     }
 
     /** Builder for configuring and creating instances of {@link TypeInference}. */
@@ -144,6 +152,19 @@ public Optional<TypeStrategy> getAccumulatorTypeStrategy() {
         return Optional.of(stateTypeStrategies.values().iterator().next());
     }
 
+    private void checkStateEntries() {
+        // Verify state
+        final List<String> invalidStateEntries =
+                stateTypeStrategies.keySet().stream()
+                        .filter(n -> !PARAMETER_NAME_FORMAT.test(n))
+                        .collect(Collectors.toList());
+        if (!invalidStateEntries.isEmpty()) {
+            throw new ValidationException(
+                    "Invalid state names. A state entry must follow the pattern [a-zA-Z_$][a-zA-Z_$0-9]. But found: "
+                            + invalidStateEntries);
+        }
+    }
+
     // --------------------------------------------------------------------------------------------
     // Builder
     // --------------------------------------------------------------------------------------------
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java
index 03e1e66d137479..21eed2f0277124 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/DataTypeExtractorTest.java
@@ -579,7 +579,8 @@ static TestSpec forMethodOutput(String description, Class<?> clazz) {
         final Method method = clazz.getMethods()[0];
         return new TestSpec(
                 description,
-                (lookup) -> DataTypeExtractor.extractFromMethodOutput(lookup, clazz, method));
+                (lookup) ->
+                        DataTypeExtractor.extractFromMethodReturnType(lookup, clazz, method));
     }
 
     static TestSpec forMethodOutput(Class<?> clazz) {
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
index 5fbc6dc81d354c..9b7ee509926cff 100644
--- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
+++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java
@@ -31,6 +31,7 @@
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.functions.AggregateFunction;
 import org.apache.flink.table.functions.AsyncScalarFunction;
+import org.apache.flink.table.functions.ProcessTableFunction;
 import org.apache.flink.table.functions.ScalarFunction;
 import org.apache.flink.table.functions.TableAggregateFunction;
 import org.apache.flink.table.functions.TableFunction;
@@ -39,6 +40,10 @@
 import org.apache.flink.table.types.inference.ArgumentTypeStrategy;
 import org.apache.flink.table.types.inference.InputTypeStrategies;
 import org.apache.flink.table.types.inference.InputTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategy;
+import org.apache.flink.table.types.inference.StateTypeStrategyWrapper;
+import 
org.apache.flink.table.types.inference.StaticArgument; +import org.apache.flink.table.types.inference.StaticArgumentTrait; import org.apache.flink.table.types.inference.TypeInference; import org.apache.flink.table.types.inference.TypeStrategies; import org.apache.flink.table.types.inference.TypeStrategy; @@ -50,7 +55,10 @@ import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; +import java.util.EnumSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -74,17 +82,10 @@ private static Stream functionSpecs() { return Stream.of( // function hint defines everything TestSpec.forScalarFunction(FullFunctionHint.class) - .expectNamedArguments("i", "s") - .expectTypedArguments(DataTypes.INT(), DataTypes.STRING()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "s"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()), - InputTypeStrategies.explicit(DataTypes.STRING()) - }), - TypeStrategies.explicit(DataTypes.BOOLEAN())), - + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.BOOLEAN())), + // --- // function hint defines everything with overloading TestSpec.forScalarFunction(FullFunctionHints.class) .expectOutputMapping( @@ -95,7 +96,7 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.BIGINT())), TypeStrategies.explicit(DataTypes.BIGINT())), - + // --- // global output hint with local input overloading TestSpec.forScalarFunction(GlobalOutputFunctionHint.class) .expectOutputMapping( @@ -106,12 +107,12 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.STRING())), TypeStrategies.explicit(DataTypes.INT())), - + // --- // unsupported output overloading TestSpec.forScalarFunction(InvalidSingleOutputFunctionHint.class) .expectErrorMessage( "Function hints that lead to ambiguous results are not allowed."), - + // --- // global and local overloading TestSpec.forScalarFunction(SplitFullFunctionHints.class) .expectOutputMapping( @@ -122,22 +123,21 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.BIGINT())), TypeStrategies.explicit(DataTypes.BIGINT())), - + // --- // global and local overloading with unsupported output overloading TestSpec.forScalarFunction(InvalidFullOutputFunctionHint.class) .expectErrorMessage( "Function hints with same input definition but different result types are not allowed."), - + // --- // ignore argument names during overloading TestSpec.forScalarFunction(InvalidFullOutputFunctionWithArgNamesHint.class) .expectErrorMessage( "Function hints with same input definition but different result types are not allowed."), - + // --- // invalid data type hint TestSpec.forScalarFunction(IncompleteFunctionHint.class) - .expectErrorMessage( - "Data type hint does neither specify a data type nor input group for use as function argument."), - + .expectErrorMessage("Data type missing for scalar argument at position 1."), + // --- // varargs and ANY input group TestSpec.forScalarFunction(ComplexFunctionHint.class) .expectOutputMapping( @@ -149,7 +149,7 @@ private static Stream functionSpecs() { InputTypeStrategies.ANY }), TypeStrategies.explicit(DataTypes.BOOLEAN())), - + // --- // global input hints and local output 
hints TestSpec.forScalarFunction(GlobalInputFunctionHints.class) .expectOutputMapping( @@ -160,55 +160,33 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.BIGINT())), TypeStrategies.explicit(DataTypes.INT())), - + // --- // no arguments TestSpec.forScalarFunction(ZeroArgFunction.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[0], new ArgumentTypeStrategy[0]), - TypeStrategies.explicit(DataTypes.INT())), - + .expectEmptyStaticArguments() + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // no arguments async TestSpec.forAsyncScalarFunction(ZeroArgFunctionAsync.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[0], new ArgumentTypeStrategy[0]), - TypeStrategies.explicit(DataTypes.INT())), - + .expectEmptyStaticArguments() + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // test primitive arguments extraction TestSpec.forScalarFunction(MixedArgFunction.class) - .expectNamedArguments("i", "d") - .expectTypedArguments( - DataTypes.INT().notNull().bridgedTo(int.class), DataTypes.DOUBLE()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "d"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit( - DataTypes.INT().notNull().bridgedTo(int.class)), - InputTypeStrategies.explicit(DataTypes.DOUBLE()) - }), - TypeStrategies.explicit(DataTypes.INT())), - + .expectStaticArgument( + StaticArgument.scalar( + "i", DataTypes.INT().notNull().bridgedTo(int.class), false)) + .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DOUBLE(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // test primitive arguments extraction async TestSpec.forAsyncScalarFunction(MixedArgFunctionAsync.class) - .expectNamedArguments("i", "d") - .expectTypedArguments( - DataTypes.INT().notNull().bridgedTo(int.class), DataTypes.DOUBLE()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "d"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit( - DataTypes.INT().notNull().bridgedTo(int.class)), - InputTypeStrategies.explicit(DataTypes.DOUBLE()) - }), - TypeStrategies.explicit(DataTypes.INT())), - + .expectStaticArgument( + StaticArgument.scalar( + "i", DataTypes.INT().notNull().bridgedTo(int.class), false)) + .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DOUBLE(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // test overloaded arguments extraction TestSpec.forScalarFunction(OverloadedFunction.class) .expectOutputMapping( @@ -228,7 +206,7 @@ private static Stream functionSpecs() { }), TypeStrategies.explicit( DataTypes.BIGINT().notNull().bridgedTo(long.class))), - + // --- // test overloaded arguments extraction async TestSpec.forAsyncScalarFunction(OverloadedFunctionAsync.class) .expectOutputMapping( @@ -247,7 +225,7 @@ private static Stream functionSpecs() { InputTypeStrategies.explicit(DataTypes.STRING()) }), TypeStrategies.explicit(DataTypes.BIGINT())), - + // --- // test varying arguments extraction TestSpec.forScalarFunction(VarArgFunction.class) .expectOutputMapping( @@ -260,7 +238,7 @@ private static Stream functionSpecs() { DataTypes.INT().notNull().bridgedTo(int.class)) }), TypeStrategies.explicit(DataTypes.STRING())), - + // --- // test varying arguments extraction async 
TestSpec.forAsyncScalarFunction(VarArgFunctionAsync.class) .expectOutputMapping( @@ -273,7 +251,7 @@ private static Stream functionSpecs() { DataTypes.INT().notNull().bridgedTo(int.class)) }), TypeStrategies.explicit(DataTypes.STRING())), - + // --- // test varying arguments extraction with byte TestSpec.forScalarFunction(VarArgWithByteFunction.class) .expectOutputMapping( @@ -286,7 +264,7 @@ private static Stream functionSpecs() { .bridgedTo(byte.class)) }), TypeStrategies.explicit(DataTypes.STRING())), - + // --- // test varying arguments extraction with byte async TestSpec.forAsyncScalarFunction(VarArgWithByteFunctionAsync.class) .expectOutputMapping( @@ -299,57 +277,49 @@ private static Stream functionSpecs() { .bridgedTo(byte.class)) }), TypeStrategies.explicit(DataTypes.STRING())), - + // --- // output hint with input extraction TestSpec.forScalarFunction(ExtractWithOutputHintFunction.class) - .expectNamedArguments("i") - .expectTypedArguments(DataTypes.INT()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()) - }), - TypeStrategies.explicit(DataTypes.INT())), - + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // output hint with input extraction TestSpec.forAsyncScalarFunction(ExtractWithOutputHintFunctionAsync.class) - .expectNamedArguments("i") - .expectTypedArguments(DataTypes.INT()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()) - }), - TypeStrategies.explicit(DataTypes.INT())), - + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // output extraction with input hints TestSpec.forScalarFunction(ExtractWithInputHintFunction.class) - .expectNamedArguments("i", "b") - .expectTypedArguments(DataTypes.INT(), DataTypes.BOOLEAN()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "b"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()), - InputTypeStrategies.explicit(DataTypes.BOOLEAN()) - }), + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectStaticArgument( + StaticArgument.scalar("b", DataTypes.BOOLEAN(), false)) + .expectOutput( TypeStrategies.explicit( DataTypes.DOUBLE().notNull().bridgedTo(double.class))), - + // --- // different accumulator depending on input TestSpec.forAggregateFunction(InputDependentAccumulatorFunction.class) - .expectAccumulatorMapping( - InputTypeStrategies.sequence( - InputTypeStrategies.explicit(DataTypes.BIGINT())), - TypeStrategies.explicit( - DataTypes.ROW(DataTypes.FIELD("f", DataTypes.BIGINT())))) - .expectAccumulatorMapping( - InputTypeStrategies.sequence( - InputTypeStrategies.explicit(DataTypes.STRING())), - TypeStrategies.explicit( - DataTypes.ROW(DataTypes.FIELD("f", DataTypes.STRING())))) + .expectAccumulator( + TypeStrategies.mapping( + Map.of( + InputTypeStrategies.sequence( + InputTypeStrategies.explicit( + DataTypes.BIGINT())), + TypeStrategies.explicit( + DataTypes.ROW( + DataTypes.FIELD( + "f", + DataTypes + .BIGINT()))), + InputTypeStrategies.sequence( + InputTypeStrategies.explicit( + DataTypes.STRING())), + TypeStrategies.explicit( + DataTypes.ROW( + DataTypes.FIELD( + "f", + DataTypes + .STRING())))))) .expectOutputMapping( 
InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.BIGINT())), @@ -358,81 +328,72 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.STRING())), TypeStrategies.explicit(DataTypes.STRING())), - + // --- // input, accumulator, and output are spread across the function TestSpec.forAggregateFunction(AggregateFunctionWithManyAnnotations.class) - .expectNamedArguments("r") - .expectTypedArguments( - DataTypes.ROW( - DataTypes.FIELD("i", DataTypes.INT()), - DataTypes.FIELD("b", DataTypes.BOOLEAN()))) - .expectAccumulatorMapping( - InputTypeStrategies.sequence( - new String[] {"r"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit( - DataTypes.ROW( - DataTypes.FIELD("i", DataTypes.INT()), - DataTypes.FIELD( - "b", DataTypes.BOOLEAN()))) - }), + .expectStaticArgument( + StaticArgument.scalar( + "r", + DataTypes.ROW( + DataTypes.FIELD("i", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.BOOLEAN())), + false)) + .expectAccumulator( TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("b", DataTypes.BOOLEAN())))) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"r"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit( - DataTypes.ROW( - DataTypes.FIELD("i", DataTypes.INT()), - DataTypes.FIELD( - "b", DataTypes.BOOLEAN()))) - }), - TypeStrategies.explicit(DataTypes.STRING())), - + .expectOutput(TypeStrategies.explicit(DataTypes.STRING())), + // --- + // accumulator with state hint + TestSpec.forAggregateFunction(StateHintAggregateFunction.class) + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectState("myAcc", TypeStrategies.explicit(MyState.TYPE)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + // accumulator with state hint in function hint + TestSpec.forAggregateFunction(StateHintInFunctionHintAggregateFunction.class) + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectState("myAcc", TypeStrategies.explicit(MyState.TYPE)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // test for table functions TestSpec.forTableFunction(OutputHintTableFunction.class) - .expectNamedArguments("i") - .expectTypedArguments(DataTypes.INT().notNull().bridgedTo(int.class)) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit( - DataTypes.INT().notNull().bridgedTo(int.class)) - }), + .expectStaticArgument( + StaticArgument.scalar( + "i", DataTypes.INT().notNull().bridgedTo(int.class), false)) + .expectOutput( TypeStrategies.explicit( DataTypes.ROW( DataTypes.FIELD("i", DataTypes.INT()), DataTypes.FIELD("b", DataTypes.BOOLEAN())))), - + // --- // mismatch between hints and implementation regarding return type TestSpec.forScalarFunction(InvalidMethodScalarFunction.class) .expectErrorMessage( "Considering all hints, the method should comply with the signature:\n" + "java.lang.String eval(int[])"), - + // --- // mismatch between hints and implementation regarding return type TestSpec.forAsyncScalarFunction(InvalidMethodScalarFunctionAsync.class) .expectErrorMessage( "Considering all hints, the method should comply with the signature:\n" + "eval(java.util.concurrent.CompletableFuture, int[])"), - + // --- // mismatch between hints and implementation regarding accumulator TestSpec.forAggregateFunction(InvalidMethodAggregateFunction.class) .expectErrorMessage( "Considering all hints, the method should comply 
with the signature:\n" - + "accumulate(java.lang.Integer, int, boolean)"), - + + "accumulate(java.lang.Integer, int, boolean)\n" + + "Pattern: ( [, ]*)"), + // --- // no implementation TestSpec.forTableFunction(MissingMethodTableFunction.class) .expectErrorMessage( "Could not find a publicly accessible method named 'eval'."), - + // --- // named arguments with overloaded function // expected no named argument for overloaded function TestSpec.forScalarFunction(NamedArgumentsScalarFunction.class), - + // --- // scalar function that takes any input TestSpec.forScalarFunction(InputGroupScalarFunction.class) .expectOutputMapping( @@ -440,7 +401,7 @@ private static Stream functionSpecs() { new String[] {"o"}, new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}), TypeStrategies.explicit(DataTypes.STRING())), - + // --- // scalar function that takes any input as vararg TestSpec.forScalarFunction(VarArgInputGroupScalarFunction.class) .expectOutputMapping( @@ -448,6 +409,7 @@ private static Stream functionSpecs() { new String[] {"o"}, new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}), TypeStrategies.explicit(DataTypes.STRING())), + // --- TestSpec.forScalarFunction( "Scalar function with implicit overloading order", OrderedScalarFunction.class) @@ -465,6 +427,7 @@ private static Stream functionSpecs() { InputTypeStrategies.explicit(DataTypes.BIGINT()) }), TypeStrategies.explicit(DataTypes.BIGINT())), + // --- TestSpec.forScalarFunction( "Scalar function with explicit overloading order by class annotations", OrderedScalarFunction2.class) @@ -476,6 +439,7 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.INT())), TypeStrategies.explicit(DataTypes.INT())), + // --- TestSpec.forScalarFunction( "Scalar function with explicit overloading order by method annotations", OrderedScalarFunction3.class) @@ -487,138 +451,131 @@ private static Stream functionSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.INT())), TypeStrategies.explicit(DataTypes.INT())), + // --- TestSpec.forTableFunction( "A data type hint on the class is used instead of a function output hint", DataTypeHintOnTableFunctionClass.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {}, new ArgumentTypeStrategy[] {}), + .expectEmptyStaticArguments() + .expectOutput( TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())))), + // --- TestSpec.forTableFunction( "A data type hint on the method is used instead of a function output hint", DataTypeHintOnTableFunctionMethod.class) - .expectNamedArguments("i") - .expectTypedArguments(DataTypes.INT()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()) - }), + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectOutput( TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())))), + // --- TestSpec.forTableFunction( "Invalid data type hint on top of method and class", InvalidDataTypeHintOnTableFunction.class) .expectErrorMessage( "More than one data type hint found for output of function. 
" + "Please use a function hint instead."), + // --- TestSpec.forScalarFunction( "A data type hint on the method is used for enriching (not a function output hint)", DataTypeHintOnScalarFunction.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {}, new ArgumentTypeStrategy[] {}), + .expectEmptyStaticArguments() + .expectOutput( TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())) .bridgedTo(RowData.class))), + // --- TestSpec.forAsyncScalarFunction( "A data type hint on the method is used for enriching (not a function output hint)", DataTypeHintOnScalarFunctionAsync.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {}, new ArgumentTypeStrategy[] {}), + .expectEmptyStaticArguments() + .expectOutput( TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())) .bridgedTo(RowData.class))), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints", ArgumentHintScalarFunction.class) - .expectNamedArguments("f1", "f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"f1", "f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), - TypeStrategies.explicit(DataTypes.STRING())), + .expectStaticArgument( + StaticArgument.scalar("f1", DataTypes.STRING(), false)) + .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.STRING())), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints missing type", ArgumentHintMissingTypeScalarFunction.class) - .expectErrorMessage("The type of the argument at position 0 is not set."), + .expectErrorMessage("Data type missing for scalar argument at position 0."), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints all missing name", ArgumentHintMissingNameScalarFunction.class) - .expectNamedArguments("arg0", "arg1") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()), + .expectOutputMapping( + InputTypeStrategies.sequence( + InputTypeStrategies.explicit(DataTypes.STRING()), + InputTypeStrategies.explicit(DataTypes.INT())), + TypeStrategies.explicit(DataTypes.STRING())), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints all missing partial name", ArgumentHintMissingPartialNameScalarFunction.class) .expectErrorMessage( - "The argument name in function hint must be either fully set or not set at all."), + "Argument names in function hint must be either fully set or not set at all."), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints name conflict", ArgumentHintNameConflictScalarFunction.class) .expectErrorMessage( "Argument name conflict, there are at least two argument names that are the same."), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints on method parameter", ArgumentHintOnParameterScalarFunction.class) - .expectNamedArguments("in1", "in2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(false, false) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"in1", "in2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), - 
TypeStrategies.explicit(DataTypes.STRING())), + .expectStaticArgument( + StaticArgument.scalar("in1", DataTypes.STRING(), false)) + .expectStaticArgument(StaticArgument.scalar("in2", DataTypes.INT(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.STRING())), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hints and inputs hints both defined", ArgumentsAndInputsScalarFunction.class) .expectErrorMessage( "Argument and input hints cannot be declared in the same function hint."), + // --- TestSpec.forScalarFunction( - "Scalar function with argument hint and dataType hint declared in the same parameter", + "Scalar function with argument hint and data type hint declared in the same parameter", ArgumentsHintAndDataTypeHintScalarFunction.class) .expectErrorMessage( - "Argument and dataType hints cannot be declared in the same parameter at position 0."), + "Argument and data type hints cannot be declared at the same time at position 0."), + // --- TestSpec.forScalarFunction( "An invalid scalar function that declare FunctionHint for both class and method in the same class.", InvalidFunctionHintOnClassAndMethod.class) .expectErrorMessage( "Argument and input hints cannot be declared in the same function hint."), + // --- TestSpec.forScalarFunction( "A valid scalar class that declare FunctionHint for both class and method in the same class.", ValidFunctionHintOnClassAndMethod.class) - .expectNamedArguments("f1", "f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(true, true), + .expectStaticArgument(StaticArgument.scalar("f1", DataTypes.STRING(), true)) + .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), true)), + // --- TestSpec.forScalarFunction( "The FunctionHint of the function conflicts with the method.", ScalarFunctionWithFunctionHintConflictMethod.class) .expectErrorMessage( "Considering all hints, the method should comply with the signature"), + // --- // For function with overloaded function, argument name will be empty TestSpec.forScalarFunction( "Scalar function with overloaded functions and arguments hint declared.", ArgumentsHintScalarFunctionWithOverloadedFunction.class), + // --- TestSpec.forScalarFunction( "Scalar function with argument type not null but optional.", ArgumentHintNotNullTypeWithOptionalsScalarFunction.class) .expectErrorMessage( "Argument at position 0 is optional but its type doesn't accept null value."), + // --- TestSpec.forScalarFunction( "Scalar function with arguments hint and variable length args", ArgumentHintVariableLengthScalarFunction.class) @@ -630,29 +587,143 @@ private static Stream functionSpecs() { InputTypeStrategies.explicit(DataTypes.INT()) }), TypeStrategies.explicit(DataTypes.STRING())), - TestSpec.forScalarFunction(FunctionHintTableArgScalarFunction.class) - .expectErrorMessage("Only scalar arguments are supported so far."), - TestSpec.forScalarFunction(ArgumentHintTableArgScalarFunction.class) - .expectErrorMessage("Only scalar arguments are supported so far."), - TestSpec.forScalarFunction(StateHintScalarFunction.class) - .expectErrorMessage("State hints are not supported yet.")); + // --- + TestSpec.forProcessTableFunction(StatelessProcessTableFunction.class) + .expectStaticArgument( + StaticArgument.scalar( + "i", DataTypes.INT().notNull().bridgedTo(int.class), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(StateProcessTableFunction.class) + .expectStaticArgument(StaticArgument.scalar("i", 
DataTypes.INT(), false)) + .expectState("s", TypeStrategies.explicit(MyState.TYPE)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(NamedStateProcessTableFunction.class) + .expectStaticArgument( + StaticArgument.scalar("myArg", DataTypes.INT(), false)) + .expectState("myState", TypeStrategies.explicit(MyState.TYPE)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(MultiStateProcessTableFunction.class) + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectState("s1", TypeStrategies.explicit(MyFirstState.TYPE)) + .expectState("s2", TypeStrategies.explicit(MySecondState.TYPE)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(UntypedTableArgProcessTableFunction.class) + .expectStaticArgument( + StaticArgument.table( + "t", + Row.class, + false, + EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW))) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(TypedTableArgProcessTableFunction.class) + .expectStaticArgument( + StaticArgument.table( + "t", + TypedTableArgProcessTableFunction.Customer.TYPE, + false, + EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW))) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(ComplexProcessTableFunction.class) + .expectStaticArgument( + StaticArgument.table( + "setTable", + RowData.class, + false, + EnumSet.of( + StaticArgumentTrait.TABLE_AS_SET, + StaticArgumentTrait.OPTIONAL_PARTITION_BY))) + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectStaticArgument( + StaticArgument.table( + "rowTable", + Row.class, + true, + EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW))) + .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), true)) + .expectState("s1", TypeStrategies.explicit(MyFirstState.TYPE)) + .expectState( + "other", + TypeStrategies.explicit( + DataTypes.ROW(DataTypes.FIELD("f", DataTypes.FLOAT())))) + .expectOutput( + TypeStrategies.explicit( + DataTypes.ROW(DataTypes.FIELD("b", DataTypes.BOOLEAN())))), + // --- + TestSpec.forProcessTableFunction(ComplexProcessTableFunctionWithFunctionHint.class) + .expectStaticArgument( + StaticArgument.table( + "setTable", + RowData.class, + false, + EnumSet.of( + StaticArgumentTrait.TABLE_AS_SET, + StaticArgumentTrait.OPTIONAL_PARTITION_BY))) + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectStaticArgument( + StaticArgument.table( + "rowTable", + Row.class, + true, + EnumSet.of(StaticArgumentTrait.TABLE_AS_ROW))) + .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), true)) + .expectState("s1", TypeStrategies.explicit(MyFirstState.TYPE)) + .expectState( + "other", + TypeStrategies.explicit( + DataTypes.ROW(DataTypes.FIELD("f", DataTypes.FLOAT())))) + .expectOutput( + TypeStrategies.explicit( + DataTypes.ROW(DataTypes.FIELD("b", DataTypes.BOOLEAN())))), + // --- + TestSpec.forProcessTableFunction(WrongStateOrderProcessTableFunction.class) + .expectErrorMessage( + "Considering all hints, the method should comply with the signature:\n" + + "eval(org.apache.flink.table.types.extraction.TypeInferenceExtractorTest.MyFirstState, int)\n" + + "Pattern: (? 
[, ]* [, ]*)"), + // --- + TestSpec.forProcessTableFunction(MissingStateTypeProcessTableFunction.class) + .expectErrorMessage( + "Could not extract a data type from 'class java.lang.Object' in parameter 0 of method 'eval'"), + // --- + TestSpec.forProcessTableFunction(EnrichedExtractionStateProcessTableFunction.class) + .expectState("d", TypeStrategies.explicit(DataTypes.DECIMAL(3, 2))) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- + TestSpec.forProcessTableFunction(WrongTypedTableProcessTableFunction.class) + .expectErrorMessage( + "Invalid data type 'INT' for table argument 'i'. " + + "Typed table arguments must use a composite type (i.e. row or structured type)."), + // --- + TestSpec.forProcessTableFunction(WrongArgumentTraitsProcessTableFunction.class) + .expectErrorMessage( + "Invalid argument traits for argument 'r'. " + + "Trait OPTIONAL_PARTITION_BY requires TABLE_AS_SET."), + // --- + TestSpec.forProcessTableFunction( + MixingStaticAndInputGroupProcessTableFunction.class) + .expectErrorMessage( + "Process table functions require a non-overloaded, non-vararg, and static signature."), + // --- + TestSpec.forProcessTableFunction(MultiEvalProcessTableFunction.class) + .expectErrorMessage( + "Process table functions require a non-overloaded, non-vararg, and static signature.")); } private static Stream procedureSpecs() { return Stream.of( // procedure hint defines everything TestSpec.forProcedure(FullProcedureHint.class) - .expectNamedArguments("i", "s") - .expectTypedArguments(DataTypes.INT(), DataTypes.STRING()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "s"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()), - InputTypeStrategies.explicit(DataTypes.STRING()) - }), - TypeStrategies.explicit(DataTypes.BOOLEAN())), - // procedure hint defines everything + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectStaticArgument(StaticArgument.scalar("s", DataTypes.STRING(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.BOOLEAN())), + // --- + // procedure hints define everything TestSpec.forProcedure(FullProcedureHints.class) .expectOutputMapping( InputTypeStrategies.sequence( @@ -662,6 +733,7 @@ private static Stream procedureSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.BIGINT())), TypeStrategies.explicit(DataTypes.BIGINT())), + // --- // global output hint with local input overloading TestSpec.forProcedure(GlobalOutputProcedureHint.class) .expectOutputMapping( @@ -672,6 +744,7 @@ private static Stream procedureSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.STRING())), TypeStrategies.explicit(DataTypes.INT())), + // --- // global and local overloading TestSpec.forProcedure(SplitFullProcedureHints.class) .expectOutputMapping( @@ -682,6 +755,7 @@ private static Stream procedureSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.BIGINT())), TypeStrategies.explicit(DataTypes.BIGINT())), + // --- // varargs and ANY input group TestSpec.forProcedure(ComplexProcedureHint.class) .expectOutputMapping( @@ -693,6 +767,7 @@ private static Stream procedureSpecs() { InputTypeStrategies.ANY }), TypeStrategies.explicit(DataTypes.BOOLEAN())), + // --- // global input hints and local output hints TestSpec.forProcedure(GlobalInputProcedureHints.class) .expectOutputMapping( @@ -703,28 +778,20 @@ private static Stream procedureSpecs() { InputTypeStrategies.sequence( 
InputTypeStrategies.explicit(DataTypes.BIGINT())), TypeStrategies.explicit(DataTypes.INT())), + // --- // no arguments TestSpec.forProcedure(ZeroArgProcedure.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[0], new ArgumentTypeStrategy[0]), - TypeStrategies.explicit(DataTypes.INT())), + .expectEmptyStaticArguments() + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // test primitive arguments extraction TestSpec.forProcedure(MixedArgProcedure.class) - .expectNamedArguments("i", "d") - .expectTypedArguments( - DataTypes.INT().notNull().bridgedTo(int.class), DataTypes.DOUBLE()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "d"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit( - DataTypes.INT().notNull().bridgedTo(int.class)), - InputTypeStrategies.explicit(DataTypes.DOUBLE()) - }), - TypeStrategies.explicit(DataTypes.INT())), + .expectStaticArgument( + StaticArgument.scalar( + "i", DataTypes.INT().notNull().bridgedTo(int.class), false)) + .expectStaticArgument(StaticArgument.scalar("d", DataTypes.DOUBLE(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // test overloaded arguments extraction TestSpec.forProcedure(OverloadedProcedure.class) .expectOutputMapping( @@ -744,6 +811,7 @@ private static Stream procedureSpecs() { }), TypeStrategies.explicit( DataTypes.BIGINT().notNull().bridgedTo(long.class))), + // --- // test varying arguments extraction TestSpec.forProcedure(VarArgProcedure.class) .expectOutputMapping( @@ -756,6 +824,7 @@ private static Stream procedureSpecs() { DataTypes.INT().notNull().bridgedTo(int.class)) }), TypeStrategies.explicit(DataTypes.STRING())), + // --- // test varying arguments extraction with byte TestSpec.forProcedure(VarArgWithByteProcedure.class) .expectOutputMapping( @@ -768,33 +837,25 @@ private static Stream procedureSpecs() { .bridgedTo(byte.class)) }), TypeStrategies.explicit(DataTypes.STRING())), + // --- // output hint with input extraction TestSpec.forProcedure(ExtractWithOutputHintProcedure.class) - .expectNamedArguments("i") - .expectTypedArguments(DataTypes.INT()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()) - }), - TypeStrategies.explicit(DataTypes.INT())), + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectOutput(TypeStrategies.explicit(DataTypes.INT())), + // --- // output extraction with input hints TestSpec.forProcedure(ExtractWithInputHintProcedure.class) - .expectNamedArguments("i", "b") - .expectTypedArguments(DataTypes.INT(), DataTypes.BOOLEAN()) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"i", "b"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.INT()), - InputTypeStrategies.explicit(DataTypes.BOOLEAN()) - }), + .expectStaticArgument(StaticArgument.scalar("i", DataTypes.INT(), false)) + .expectStaticArgument( + StaticArgument.scalar("b", DataTypes.BOOLEAN(), false)) + .expectOutput( TypeStrategies.explicit( DataTypes.DOUBLE().notNull().bridgedTo(double.class))), + // --- // named arguments with overloaded function // expected no named argument for overloaded function TestSpec.forProcedure(NamedArgumentsProcedure.class), + // --- // procedure function that takes any input TestSpec.forProcedure(InputGroupProcedure.class) .expectOutputMapping( @@ -802,6 +863,7 @@ private static 
Stream procedureSpecs() { new String[] {"o"}, new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}), TypeStrategies.explicit(DataTypes.STRING())), + // --- // procedure function that takes any input as vararg TestSpec.forProcedure(VarArgInputGroupProcedure.class) .expectOutputMapping( @@ -809,6 +871,7 @@ private static Stream procedureSpecs() { new String[] {"o"}, new ArgumentTypeStrategy[] {InputTypeStrategies.ANY}), TypeStrategies.explicit(DataTypes.STRING())), + // --- TestSpec.forProcedure( "Procedure with implicit overloading order", OrderedProcedure.class) .expectOutputMapping( @@ -825,6 +888,7 @@ private static Stream procedureSpecs() { InputTypeStrategies.explicit(DataTypes.BIGINT()) }), TypeStrategies.explicit(DataTypes.BIGINT())), + // --- TestSpec.forProcedure( "Procedure with explicit overloading order by class annotations", OrderedProcedure2.class) @@ -836,6 +900,7 @@ private static Stream procedureSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.INT())), TypeStrategies.explicit(DataTypes.INT())), + // --- TestSpec.forProcedure( "Procedure with explicit overloading order by method annotations", OrderedProcedure3.class) @@ -847,181 +912,141 @@ private static Stream procedureSpecs() { InputTypeStrategies.sequence( InputTypeStrategies.explicit(DataTypes.INT())), TypeStrategies.explicit(DataTypes.INT())), + // --- TestSpec.forProcedure( "A data type hint on the method is used for enriching (not a function output hint)", DataTypeHintOnProcedure.class) - .expectNamedArguments() - .expectTypedArguments() - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {}, new ArgumentTypeStrategy[] {}), + .expectEmptyStaticArguments() + .expectOutput( TypeStrategies.explicit( DataTypes.ROW(DataTypes.FIELD("i", DataTypes.INT())) .bridgedTo(RowData.class))), + // --- // unsupported output overloading TestSpec.forProcedure(InvalidSingleOutputProcedureHint.class) .expectErrorMessage( "Procedure hints that lead to ambiguous results are not allowed."), + // --- // global and local overloading with unsupported output overloading TestSpec.forProcedure(InvalidFullOutputProcedureHint.class) .expectErrorMessage( "Procedure hints with same input definition but different result types are not allowed."), + // --- // ignore argument names during overloading TestSpec.forProcedure(InvalidFullOutputProcedureWithArgNamesHint.class) .expectErrorMessage( "Procedure hints with same input definition but different result types are not allowed."), + // --- // invalid data type hint TestSpec.forProcedure(IncompleteProcedureHint.class) - .expectErrorMessage( - "Data type hint does neither specify a data type nor input group for use as function argument."), + .expectErrorMessage("Data type missing for scalar argument at position 1."), + // --- // mismatch between hints and implementation regarding return type TestSpec.forProcedure(InvalidMethodProcedure.class) .expectErrorMessage( "Considering all hints, the method should comply with the signature:\n" - + "java.lang.String[] call(_, int[])"), + + "java.lang.String[] call(_, int[])\n" + + "Pattern: ( [, ]*)"), + // --- // no implementation TestSpec.forProcedure(MissingMethodProcedure.class) .expectErrorMessage( "Could not find a publicly accessible method named 'call'."), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint on method", ArgumentHintOnMethodProcedure.class) - .expectNamedArguments("f1", "f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(true, 
true) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"f1", "f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), + .expectStaticArgument(StaticArgument.scalar("f1", DataTypes.STRING(), true)) + .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), true)) + .expectOutput( TypeStrategies.explicit( DataTypes.INT().notNull().bridgedTo(int.class))), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint on class", ArgumentHintOnClassProcedure.class) - .expectNamedArguments("f1", "f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(true, true) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"f1", "f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), + .expectStaticArgument(StaticArgument.scalar("f1", DataTypes.STRING(), true)) + .expectStaticArgument(StaticArgument.scalar("f2", DataTypes.INT(), true)) + .expectOutput( TypeStrategies.explicit( DataTypes.INT().notNull().bridgedTo(int.class))), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint on parameter", ArgumentHintOnParameterProcedure.class) - .expectNamedArguments("parameter_f1", "parameter_f2") - .expectTypedArguments( - DataTypes.STRING(), DataTypes.INT().bridgedTo(int.class)) - .expectOptionalArguments(true, false) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"parameter_f1", "parameter_f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit( - DataTypes.INT().bridgedTo(int.class)) - }), + .expectStaticArgument( + StaticArgument.scalar("parameter_f1", DataTypes.STRING(), true)) + .expectStaticArgument( + StaticArgument.scalar( + "parameter_f2", + DataTypes.INT().bridgedTo(int.class), + false)) + .expectOutput( TypeStrategies.explicit( DataTypes.INT().notNull().bridgedTo(int.class))), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint on method and parameter", ArgumentHintOnMethodAndParameterProcedure.class) - .expectNamedArguments("local_f1", "local_f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(true, true) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"local_f1", "local_f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), + .expectStaticArgument( + StaticArgument.scalar("local_f1", DataTypes.STRING(), true)) + .expectStaticArgument( + StaticArgument.scalar("local_f2", DataTypes.INT(), true)) + .expectOutput( TypeStrategies.explicit( DataTypes.INT().notNull().bridgedTo(int.class))), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint on class and method", ArgumentHintOnClassAndMethodProcedure.class) - .expectNamedArguments("global_f1", "global_f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(false, false) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"global_f1", "global_f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), + .expectStaticArgument( + StaticArgument.scalar("global_f1", DataTypes.STRING(), false)) + .expectStaticArgument( + 
StaticArgument.scalar("global_f2", DataTypes.INT(), false)) + .expectOutput( TypeStrategies.explicit( DataTypes.INT().notNull().bridgedTo(int.class))), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint on class and method and parameter", ArgumentHintOnClassAndMethodAndParameterProcedure.class) - .expectNamedArguments("global_f1", "global_f2") - .expectTypedArguments(DataTypes.STRING(), DataTypes.INT()) - .expectOptionalArguments(false, false) - .expectOutputMapping( - InputTypeStrategies.sequence( - new String[] {"global_f1", "global_f2"}, - new ArgumentTypeStrategy[] { - InputTypeStrategies.explicit(DataTypes.STRING()), - InputTypeStrategies.explicit(DataTypes.INT()) - }), + .expectStaticArgument( + StaticArgument.scalar("global_f1", DataTypes.STRING(), false)) + .expectStaticArgument( + StaticArgument.scalar("global_f2", DataTypes.INT(), false)) + .expectOutput( TypeStrategies.explicit( DataTypes.INT().notNull().bridgedTo(int.class))), + // --- TestSpec.forProcedure( "Named arguments procedure with argument hint type not null but optional", ArgumentHintNotNullWithOptionalProcedure.class) .expectErrorMessage( "Argument at position 1 is optional but its type doesn't accept null value."), + // --- TestSpec.forProcedure( "Named arguments procedure with argument name conflict", ArgumentHintNameConflictProcedure.class) .expectErrorMessage( "Argument name conflict, there are at least two argument names that are the same."), + // --- TestSpec.forProcedure( "Named arguments procedure with optional type on primitive type", ArgumentHintOptionalOnPrimitiveParameterConflictProcedure.class) .expectErrorMessage( - "Argument at position 1 is optional but a primitive type doesn't accept null value.")); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("testData") - void testArgumentNames(TestSpec testSpec) { - if (testSpec.expectedArgumentNames != null) { - assertThat(testSpec.typeInferenceExtraction.get().getNamedArguments()) - .isEqualTo(Optional.of(testSpec.expectedArgumentNames)); - } else if (testSpec.expectedErrorMessage == null) { - assertThat(testSpec.typeInferenceExtraction.get().getNamedArguments()) - .isEqualTo(Optional.empty()); - } - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("testData") - void testArgumentOptionals(TestSpec testSpec) { - if (testSpec.expectedArgumentOptionals != null) { - assertThat(testSpec.typeInferenceExtraction.get().getOptionalArguments()) - .isEqualTo(Optional.of(testSpec.expectedArgumentOptionals)); - } + "Considering all hints, the method should comply with the signature:\n" + + "int[] call(_, java.lang.String, java.lang.Integer)")); } @ParameterizedTest(name = "{index}: {0}") @MethodSource("testData") - void testArgumentTypes(TestSpec testSpec) { - if (testSpec.expectedArgumentTypes != null) { - assertThat(testSpec.typeInferenceExtraction.get().getTypedArguments()) - .isEqualTo(Optional.of(testSpec.expectedArgumentTypes)); - } else if (testSpec.expectedErrorMessage == null) { - assertThat(testSpec.typeInferenceExtraction.get().getTypedArguments()) - .isEqualTo(Optional.empty()); + void testStaticArguments(TestSpec testSpec) { + if (testSpec.expectedStaticArguments != null) { + final Optional> staticArguments = + testSpec.typeInferenceExtraction.get().getStaticArguments(); + assertThat(staticArguments).isPresent(); + assertThat(staticArguments.get()) + .containsExactlyElementsOf(testSpec.expectedStaticArguments); } } @@ -1039,16 +1064,14 @@ void testInputTypeStrategy(TestSpec testSpec) { 
@ParameterizedTest(name = "{index}: {0}") @MethodSource("testData") - void testAccumulatorTypeStrategy(TestSpec testSpec) { - if (!testSpec.expectedAccumulatorStrategies.isEmpty()) { - assertThat( - testSpec.typeInferenceExtraction - .get() - .getAccumulatorTypeStrategy() - .isPresent()) - .isEqualTo(true); - assertThat(testSpec.typeInferenceExtraction.get().getAccumulatorTypeStrategy().get()) - .isEqualTo(TypeStrategies.mapping(testSpec.expectedAccumulatorStrategies)); + void testStateTypeStrategies(TestSpec testSpec) { + if (!testSpec.expectedStateStrategies.isEmpty()) { + assertThat(testSpec.typeInferenceExtraction.get().getStateTypeStrategies()) + .isNotEmpty(); + assertThat(testSpec.typeInferenceExtraction.get().getStateTypeStrategies()) + .isEqualTo(testSpec.expectedStateStrategies); + } else if (testSpec.expectedErrorMessage == null) { + assertThat(testSpec.typeInferenceExtraction.get().getStateTypeStrategies()).isEmpty(); } } @@ -1056,8 +1079,13 @@ void testAccumulatorTypeStrategy(TestSpec testSpec) { @MethodSource("testData") void testOutputTypeStrategy(TestSpec testSpec) { if (!testSpec.expectedOutputStrategies.isEmpty()) { - assertThat(testSpec.typeInferenceExtraction.get().getOutputTypeStrategy()) - .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies)); + if (testSpec.expectedOutputStrategies.size() == 1) { + assertThat(testSpec.typeInferenceExtraction.get().getOutputTypeStrategy()) + .isEqualTo(testSpec.expectedOutputStrategies.values().iterator().next()); + } else { + assertThat(testSpec.typeInferenceExtraction.get().getOutputTypeStrategy()) + .isEqualTo(TypeStrategies.mapping(testSpec.expectedOutputStrategies)); + } } } @@ -1086,13 +1114,9 @@ static class TestSpec { final Supplier typeInferenceExtraction; - @Nullable List expectedArgumentNames; - - @Nullable List expectedArgumentOptionals; - - @Nullable List expectedArgumentTypes; + @Nullable List expectedStaticArguments; - Map expectedAccumulatorStrategies; + LinkedHashMap expectedStateStrategies; Map expectedOutputStrategies; @@ -1101,7 +1125,7 @@ static class TestSpec { private TestSpec(String description, Supplier typeInferenceExtraction) { this.description = description; this.typeInferenceExtraction = typeInferenceExtraction; - this.expectedAccumulatorStrategies = new LinkedHashMap<>(); + this.expectedStateStrategies = new LinkedHashMap<>(); this.expectedOutputStrategies = new LinkedHashMap<>(); } @@ -1161,6 +1185,14 @@ static TestSpec forTableAggregateFunction( new DataTypeFactoryMock(), function)); } + static TestSpec forProcessTableFunction(Class> function) { + return new TestSpec( + function.getSimpleName(), + () -> + TypeInferenceExtractor.forProcessTableFunction( + new DataTypeFactoryMock(), function)); + } + static TestSpec forProcedure(Class procedure) { return forProcedure(null, procedure); } @@ -1174,24 +1206,36 @@ static TestSpec forProcedure( new DataTypeFactoryMock(), procedure)); } - TestSpec expectNamedArguments(String... 
expectedArgumentNames) { - this.expectedArgumentNames = Arrays.asList(expectedArgumentNames); + TestSpec expectEmptyStaticArguments() { + this.expectedStaticArguments = new ArrayList<>(); + return this; + } + + TestSpec expectStaticArgument(StaticArgument argument) { + if (this.expectedStaticArguments == null) { + this.expectedStaticArguments = new ArrayList<>(); + } + this.expectedStaticArguments.add(argument); + return this; + } + + TestSpec expectAccumulator(TypeStrategy typeStrategy) { + this.expectedStateStrategies.put("acc", StateTypeStrategyWrapper.of(typeStrategy)); return this; } - TestSpec expectOptionalArguments(Boolean... expectedArgumentOptionals) { - this.expectedArgumentOptionals = Arrays.asList(expectedArgumentOptionals); + TestSpec expectState(String name, StateTypeStrategy stateTypeStrategy) { + this.expectedStateStrategies.put(name, stateTypeStrategy); return this; } - TestSpec expectTypedArguments(DataType... expectedArgumentTypes) { - this.expectedArgumentTypes = Arrays.asList(expectedArgumentTypes); + TestSpec expectState(String name, TypeStrategy typeStrategy) { + this.expectedStateStrategies.put(name, StateTypeStrategyWrapper.of(typeStrategy)); return this; } - TestSpec expectAccumulatorMapping( - InputTypeStrategy validator, TypeStrategy accumulatorStrategy) { - this.expectedAccumulatorStrategies.put(validator, accumulatorStrategy); + TestSpec expectState(LinkedHashMap stateTypeStrategy) { + this.expectedStateStrategies.putAll(stateTypeStrategy); return this; } @@ -1200,6 +1244,11 @@ TestSpec expectOutputMapping(InputTypeStrategy validator, TypeStrategy outputStr return this; } + TestSpec expectOutput(TypeStrategy outputStrategy) { + this.expectedOutputStrategies.put(InputTypeStrategies.WILDCARD, outputStrategy); + return this; + } + TestSpec expectErrorMessage(String expectedErrorMessage) { this.expectedErrorMessage = expectedErrorMessage; return this; @@ -1411,6 +1460,39 @@ public Row createAccumulator() { } } + private static class StateHintAggregateFunction extends AggregateFunction { + public void accumulate( + @StateHint(name = "myAcc") MyState acc, @ArgumentHint(name = "i") Integer i) {} + + @Override + public Integer getValue(MyState accumulator) { + return null; + } + + @Override + public MyState createAccumulator() { + return new MyState(); + } + } + + @FunctionHint( + state = {@StateHint(name = "myAcc", type = @DataTypeHint(bridgedTo = MyState.class))}, + arguments = {@ArgumentHint(name = "i", type = @DataTypeHint("INT"))}) + private static class StateHintInFunctionHintAggregateFunction + extends AggregateFunction { + public void accumulate(Object acc, Integer i) {} + + @Override + public Integer getValue(Object accumulator) { + return null; + } + + @Override + public Object createAccumulator() { + return new Object(); + } + } + @FunctionHint(output = @DataTypeHint("ROW")) private static class OutputHintTableFunction extends TableFunction { public void eval(int i) { @@ -2130,32 +2212,167 @@ public String eval(String f1, Integer... 
f2) {
         }
     }

+    private static class StatelessProcessTableFunction extends ProcessTableFunction {
+        public void eval(int i) {}
+    }
+
+    public static class MyState {
+        static final DataType TYPE =
+                DataTypes.STRUCTURED(
+                        MyState.class,
+                        DataTypes.FIELD("d", DataTypes.DOUBLE().notNull().bridgedTo(double.class)));
+        public double d;
+    }
+
+    public static class MyFirstState {
+        static final DataType TYPE =
+                DataTypes.STRUCTURED(MyFirstState.class, DataTypes.FIELD("d", DataTypes.DOUBLE()));
+        public Double d;
+    }
+
+    public static class MySecondState {
+        static final DataType TYPE =
+                DataTypes.STRUCTURED(MySecondState.class, DataTypes.FIELD("i", DataTypes.INT()));
+        public Integer i;
+    }
+
+    private static class StateProcessTableFunction extends ProcessTableFunction {
+        public void eval(@StateHint MyState s, Integer i) {}
+    }
+
+    private static class NamedStateProcessTableFunction extends ProcessTableFunction {
+        public void eval(
+                @StateHint(name = "myState") MyState s, @ArgumentHint(name = "myArg") Integer i) {}
+    }
+
+    private static class MultiStateProcessTableFunction extends ProcessTableFunction {
+        public void eval(@StateHint MyFirstState s1, @StateHint MySecondState s2, Integer i) {}
+    }
+
+    private static class UntypedTableArgProcessTableFunction extends ProcessTableFunction {
+        public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row t) {}
+    }
+
+    private static class TypedTableArgProcessTableFunction extends ProcessTableFunction {
+        public static class Customer {
+            static final DataType TYPE =
+                    DataTypes.STRUCTURED(
+                            Customer.class,
+                            DataTypes.FIELD("age", DataTypes.INT()),
+                            DataTypes.FIELD("name", DataTypes.STRING()));
+            public String name;
+            public Integer age;
+        }
+
+        public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Customer t) {}
+    }
+
+    @DataTypeHint("ROW")
+    private static class ComplexProcessTableFunction extends ProcessTableFunction {
+        public void eval(
+                Context context,
+                @StateHint(name = "s1") MyFirstState s1,
+                @StateHint(name = "other", type = @DataTypeHint("ROW")) Row s2,
+                @ArgumentHint(
+                                value = {
+                                    ArgumentTrait.TABLE_AS_SET,
+                                    ArgumentTrait.OPTIONAL_PARTITION_BY
+                                },
+                                name = "setTable")
+                        RowData t1,
+                @ArgumentHint(name = "i") Integer i,
+                @ArgumentHint(
+                                value = {ArgumentTrait.TABLE_AS_ROW},
+                                name = "rowTable",
+                                isOptional = true)
+                        Row t2,
+                @ArgumentHint(isOptional = true, name = "s") String s) {}
+    }
+
     @FunctionHint(
+            state = {
+                @StateHint(name = "s1", type = @DataTypeHint(bridgedTo = MyFirstState.class)),
+                @StateHint(name = "other", type = @DataTypeHint("ROW"))
+            },
             arguments = {
                 @ArgumentHint(
-                        value = ArgumentTrait.TABLE_AS_ROW,
-                        type = @DataTypeHint("ROW"))
-            })
-    private static class FunctionHintTableArgScalarFunction extends ScalarFunction {
-        public String eval(Row table) {
-            return "";
-        }
+                        name = "setTable",
+                        value = {ArgumentTrait.TABLE_AS_SET, ArgumentTrait.OPTIONAL_PARTITION_BY},
+                        type = @DataTypeHint(bridgedTo = RowData.class)),
+                @ArgumentHint(name = "i", type = @DataTypeHint("INT")),
+                @ArgumentHint(
+                        name = "rowTable",
+                        value = {ArgumentTrait.TABLE_AS_ROW},
+                        isOptional = true),
+                @ArgumentHint(name = "s", isOptional = true, type = @DataTypeHint("STRING"))
+            },
+            output = @DataTypeHint("ROW"))
+    private static class ComplexProcessTableFunctionWithFunctionHint
+            extends ProcessTableFunction {
+
+        public void eval(
+                Context context,
+                MyFirstState arg0,
+                Row arg1,
+                RowData arg2,
+                Integer arg3,
+                Row arg4,
+                String arg5) {}
     }

-    private static class ArgumentHintTableArgScalarFunction extends ScalarFunction {
-        public String eval(
+    private static class WrongStateOrderProcessTableFunction extends ProcessTableFunction {
+
+        public void eval(int i, @StateHint MyFirstState state) {}
+    }
+
+    private static class MissingStateTypeProcessTableFunction
+            extends ProcessTableFunction {
+
+        public void eval(@StateHint Object state) {}
+    }
+
+    private static class EnrichedExtractionStateProcessTableFunction
+            extends ProcessTableFunction {
+
+        public void eval(
+                @StateHint(
+                                type =
+                                        @DataTypeHint(
+                                                defaultDecimalPrecision = 3,
+                                                defaultDecimalScale = 2))
+                        BigDecimal d) {}
+    }
+
+    private static class WrongTypedTableProcessTableFunction extends ProcessTableFunction {
+        public void eval(@ArgumentHint(ArgumentTrait.TABLE_AS_SET) Integer i) {}
+    }
+
+    private static class WrongArgumentTraitsProcessTableFunction
+            extends ProcessTableFunction {
+        public void eval(
+                @ArgumentHint({ArgumentTrait.TABLE_AS_ROW, ArgumentTrait.OPTIONAL_PARTITION_BY})
+                        Row r) {}
+    }
+
+    private static class MixingStaticAndInputGroupProcessTableFunction
+            extends ProcessTableFunction {
+        public void eval(
+                @ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row r,
+                @DataTypeHint(inputGroup = InputGroup.ANY) Object o) {}
+    }
+
+    private static class InvalidInputGroupTableArgProcessTableFunction
+            extends ProcessTableFunction {
+        public void eval(
                 @ArgumentHint(
                         value = ArgumentTrait.TABLE_AS_ROW,
-                        type = @DataTypeHint("ROW"))
-                Row table) {
-            return "";
-        }
+                        type = @DataTypeHint(inputGroup = InputGroup.ANY))
+                        Row r) {}
     }

-    @FunctionHint(state = @StateHint(name = "state", type = @DataTypeHint("INT")))
-    private static class StateHintScalarFunction extends ScalarFunction {
-        public String eval() {
-            return "";
-        }
+    private static class MultiEvalProcessTableFunction extends ProcessTableFunction {
+        public void eval(int i) {}
+
+        public void eval(String i) {}
     }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java
index 3d3301d3c48d68..cb5c6778996051 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/TypeInferenceOperandChecker.java
@@ -27,6 +27,7 @@
 import org.apache.flink.table.types.inference.ArgumentCount;
 import org.apache.flink.table.types.inference.CallContext;
 import org.apache.flink.table.types.inference.ConstantArgumentCount;
+import org.apache.flink.table.types.inference.StaticArgument;
 import org.apache.flink.table.types.inference.TypeInference;
 import org.apache.flink.table.types.inference.TypeInferenceUtil;
 import org.apache.flink.table.types.logical.LogicalType;
@@ -47,7 +48,6 @@
 import org.apache.calcite.sql.validate.SqlValidatorNamespace;

 import java.util.List;
-import java.util.Optional;
 import java.util.stream.Collectors;

 import static org.apache.flink.table.planner.calcite.FlinkTypeFactory.toLogicalType;
@@ -83,8 +83,7 @@ public TypeInferenceOperandChecker(
         this.dataTypeFactory = dataTypeFactory;
         this.definition = definition;
         this.typeInference = typeInference;
-        this.countRange =
-                new ArgumentCountRange(typeInference.getInputTypeStrategy().getArgumentCount());
+        this.countRange = new ArgumentCountRange(deriveArgumentCount(typeInference));
     }

     @Override
@@ -105,20 +104,7 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail

     @Override
     public SqlOperandCountRange getOperandCountRange() {
-        if (typeInference.getOptionalArguments().isPresent()
-                && typeInference.getOptionalArguments().get().stream()
-                        .anyMatch(Boolean::booleanValue)) {
-            int notOptionalCount =
-                    (int)
-                            typeInference.getOptionalArguments().get().stream()
-                                    .filter(optional -> !optional)
-                                    .count();
-            ArgumentCount argumentCount =
-                    ConstantArgumentCount.between(notOptionalCount, countRange.getMax());
-            return new ArgumentCountRange(argumentCount);
-        } else {
-            return countRange;
-        }
+        return countRange;
     }

     @Override
@@ -133,22 +119,21 @@ public Consistency getConsistency() {

     @Override
     public boolean isOptional(int i) {
-        Optional<List<Boolean>> optionalArguments = typeInference.getOptionalArguments();
-        if (optionalArguments.isPresent()) {
-            return optionalArguments.get().get(i);
-        } else {
+        if (typeInference.getStaticArguments().isEmpty()) {
             return false;
         }
+        final List<StaticArgument> staticArgs = typeInference.getStaticArguments().get();
+        return staticArgs.get(i).isOptional();
     }

     @Override
     public boolean isFixedParameters() {
-        // This method returns true only if optional arguments are declared and at least one
-        // optional argument is present.
+        // This method returns true only if at least one optional argument is present.
         // Otherwise, it defaults to false, bypassing the parameter check in Calcite.
-        return typeInference.getOptionalArguments().isPresent()
-                && typeInference.getOptionalArguments().get().stream()
-                        .anyMatch(Boolean::booleanValue);
+        return typeInference
+                .getStaticArguments()
+                .map(args -> args.stream().anyMatch(StaticArgument::isOptional))
+                .orElse(false);
     }

     @Override
@@ -239,4 +224,17 @@ private void updateInferredType(SqlValidator validator, SqlNode node, RelDataTyp
             namespace.setType(type);
         }
     }
+
+    private static ArgumentCount deriveArgumentCount(TypeInference typeInference) {
+        final int staticArgs = typeInference.getStaticArguments().map(List::size).orElse(-1);
+        if (staticArgs == -1) {
+            return typeInference.getInputTypeStrategy().getArgumentCount();
+        }
+        final int optionalArgs =
+                typeInference
+                        .getStaticArguments()
+                        .map(args -> (int) args.stream().filter(StaticArgument::isOptional).count())
+                        .orElse(0);
+        return ConstantArgumentCount.between(staticArgs - optionalArgs, staticArgs);
+    }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java
index 66d6a6074e3522..a88a5e808d707f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/operations/PlannerCallProcedureOperation.java
@@ -171,12 +171,13 @@ private Object callProcedure(Procedure procedure, Class[] inputClz, Object[]
                 methods.stream()
                         .filter(
                                 method ->
-                                        ExtractionUtils.isInvokable(method, inputClz)
+                                        ExtractionUtils.isInvokable(false, method, inputClz)
                                                 && method.getReturnType().isArray()
                                                 && isAssignable(
                                                         outputType.getConversionClass(),
                                                         method.getReturnType().getComponentType(),
-                                                        true))
+                                                        true,
+                                                        false))
                         .collect(Collectors.toList());
         if (callMethods.isEmpty()) {
             throw new ValidationException(
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java
index 0a0369bc3dce90..f17259c9c69f7d 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java
@@ -212,7 +212,7 @@ private static class ProcedureWithNamedArguments implements Procedure {
                 input = {@DataTypeHint("STRING"), @DataTypeHint("STRING")},
                 output = @DataTypeHint("INT"),
                 argumentNames = {"c", "d"})
-        public int[] call(ProcedureContext context, String arg3, String arg4) {
+        public java.lang.Integer[] call(ProcedureContext context, String arg3, String arg4) {
             return null;
         }
     }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala
index 6a1fd981620dfa..2ee92e7cda0c72 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/userDefinedScalarFunctions.scala
@@ -108,7 +108,7 @@ class RichFunc1 extends ScalarFunction {
   @FunctionHint(
     input = Array(new DataTypeHint("INT")),
     output = new DataTypeHint(value = "INT", bridgedTo = classOf[JInt]))
-  def eval(index: Int): Int = {
+  def eval(index: JInt): JInt = {
     index + added
   }

diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
index 72b2678f79a5f6..0556e2c529f0f4 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/UserDefinedFunctionTestUtils.scala
@@ -382,20 +382,15 @@ object UserDefinedFunctionTestUtils {
       TestAddWithOpen.aliveCounter.incrementAndGet()
     }

-    @FunctionHint(
-      input = Array(
-        new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]),
-        new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong])),
-      output = new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]))
-    def eval(a: Long, b: Long): Long = {
+    def eval(a: JLong, b: JLong): JLong = {
       if (!isOpened) {
         throw new IllegalStateException("Open method is not called.")
       }
       a + b
     }

-    def eval(a: Long, b: Int): Long = {
-      eval(a, b.asInstanceOf[Long])
+    def eval(a: JLong, b: JInt): JLong = {
+      eval(a, b.toLong)
     }

     override def close(): Unit = {
@@ -411,13 +406,7 @@ object UserDefinedFunctionTestUtils {

   @SerialVersionUID(1L)
   object TestMod extends ScalarFunction {
-    @FunctionHint(
-      input = Array(
-        new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]),
-        new DataTypeHint(value = "INT", bridgedTo = classOf[JInt])
-      ),
-      output = new DataTypeHint(value = "BIGINT", bridgedTo = classOf[JLong]))
-    def eval(src: Long, mod: Int): Long = {
+    def eval(src: JLong, mod: JInt): JLong = {
       src % mod
     }
   }
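Side note (not part of the patch): the TypeInferenceOperandChecker change above derives the SQL operand count range from the new static arguments instead of the former Boolean optionality list. The following is a minimal, self-contained sketch of that derivation; the class name StaticArgumentCountExample and the three-argument signature are made up for illustration, while StaticArgument, ConstantArgumentCount, ArgumentCount, and DataTypes are the same APIs that appear in the diff.

import java.util.Arrays;
import java.util.List;

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.StaticArgument;

// Hypothetical sketch: mirrors the counting logic of deriveArgumentCount(TypeInference)
// from the patch, but takes the static arguments directly instead of a TypeInference.
public class StaticArgumentCountExample {

    static ArgumentCount deriveArgumentCount(List<StaticArgument> staticArgs) {
        final int total = staticArgs.size();
        final int optional =
                (int) staticArgs.stream().filter(StaticArgument::isOptional).count();
        // Optional arguments may be omitted at the call site: the minimum is the number
        // of mandatory arguments, the maximum is the full declared argument count.
        return ConstantArgumentCount.between(total - optional, total);
    }

    public static void main(String[] args) {
        // Two mandatory scalar arguments plus one optional one => range [2, 3].
        final List<StaticArgument> signature =
                Arrays.asList(
                        StaticArgument.scalar("i", DataTypes.INT(), false),
                        StaticArgument.scalar("s", DataTypes.STRING(), false),
                        StaticArgument.scalar("d", DataTypes.DOUBLE(), true));
        final ArgumentCount count = deriveArgumentCount(signature);
        System.out.println(count.getMinCount() + " .. " + count.getMaxCount());
    }
}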