From ae8a1dee12e2a6101341c233fc2755943c655c73 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Wed, 18 Dec 2024 13:42:07 +0100 Subject: [PATCH] Additional validation added --- .../extraction/BaseMappingExtractor.java | 55 +++++++++++++------ .../extraction/FunctionMappingExtractor.java | 29 +++++----- .../extraction/ProcedureMappingExtractor.java | 10 +++- .../TypeInferenceExtractorTest.java | 12 ++++ .../SqlNodeToCallOperationTest.java | 2 +- 5 files changed, 73 insertions(+), 35 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java index 2e849eb1d68e2..0e34a18d1219d 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/BaseMappingExtractor.java @@ -72,6 +72,9 @@ */ abstract class BaseMappingExtractor { + private static final EnumSet SCALAR_TRAITS = + EnumSet.of(StaticArgumentTrait.SCALAR); + protected final DataTypeFactory typeFactory; protected final String methodName; @@ -188,8 +191,15 @@ static ResultExtraction createStateFromGenericInClassOrParameters( protected abstract String getHintType(); - protected static Class[] assembleParameters(List> state, List> arguments) { - return Stream.concat(state.stream(), arguments.stream()).toArray(Class[]::new); + protected static Class[] assembleParameters( + @Nullable FunctionStateTemplate state, FunctionSignatureTemplate arguments) { + return Stream.concat( + Optional.ofNullable(state) + .map(FunctionStateTemplate::toClassList) + .orElse(List.of()) + .stream(), + arguments.toClassList().stream()) + .toArray(Class[]::new); } protected static ValidationException createMethodNotFoundError( @@ -258,19 +268,33 @@ Map extractResultMappings( return (Map) collectedMappings; } - protected static void checkNoState(@Nullable List> state) { - if (state != null && !state.isEmpty()) { + protected static void checkNoState(@Nullable FunctionStateTemplate state) { + if (state != null) { throw extractionError("State is not supported for this kind of function."); } } - protected static void checkSingleState(@Nullable List> state) { - if (state == null || state.size() != 1) { + protected static void checkSingleState(@Nullable FunctionStateTemplate state) { + if (state == null || state.toClassList().size() != 1) { throw extractionError( "Aggregating functions need exactly one state entry for the accumulator."); } } + protected static void checkScalarArgumentsOnly(FunctionSignatureTemplate arguments) { + final EnumSet[] argumentTraits = arguments.argumentTraits; + IntStream.range(0, argumentTraits.length) + .forEach( + pos -> { + if (!argumentTraits[pos].equals(SCALAR_TRAITS)) { + throw extractionError( + "Only scalar arguments are supported at this location. " + + "But argument '%s' declared the following traits: %s", + arguments.argumentNames[pos], argumentTraits[pos]); + } + }); + } + // -------------------------------------------------------------------------------------------- // Helper methods // -------------------------------------------------------------------------------------------- @@ -472,16 +496,11 @@ private void verifyMappingForMethod( (signature, result) -> { if (result instanceof FunctionStateTemplate) { final FunctionStateTemplate stateTemplate = (FunctionStateTemplate) result; - verification.verify( - method, stateTemplate.toClassList(), signature.toClassList(), null); + verification.verify(method, stateTemplate, signature, null); } else if (result instanceof FunctionOutputTemplate) { final FunctionOutputTemplate outputTemplate = (FunctionOutputTemplate) result; - verification.verify( - method, - List.of(), - signature.toClassList(), - outputTemplate.toClass()); + verification.verify(method, null, signature, outputTemplate); } }); } @@ -668,7 +687,7 @@ private static EnumSet[] extractArgumentTraits( final ArgumentHint argumentHint = arg.parameter.getAnnotation(ArgumentHint.class); if (argumentHint == null) { - return EnumSet.of(StaticArgumentTrait.SCALAR); + return SCALAR_TRAITS; } final List traits = Arrays.stream(argumentHint.value()) @@ -735,11 +754,11 @@ interface ResultExtraction { } /** Verifies the signature of a method. */ - protected interface MethodVerification { + interface MethodVerification { void verify( Method method, - List> state, - List> arguments, - @Nullable Class result); + @Nullable FunctionStateTemplate state, + FunctionSignatureTemplate arguments, + @Nullable FunctionOutputTemplate result); } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java index 44b0491f6a0c8..d1de71e9c39a5 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/FunctionMappingExtractor.java @@ -166,14 +166,16 @@ static ResultExtraction createOutputFromGenericInMethod( static MethodVerification createParameterAndReturnTypeVerification() { return (method, state, arguments, result) -> { checkNoState(state); - final Class[] parameters = assembleParameters(state, arguments); + checkScalarArgumentsOnly(arguments); + final Class[] parameters = assembleParameters(null, arguments); + assert result != null; + final Class resultClass = result.toClass(); final Class returnType = method.getReturnType(); - // TODO enable strict autoboxing final boolean isValid = - isInvokable(false, method, parameters) - && isAssignable(result, returnType, true, false); + isInvokable(true, method, parameters) + && isAssignable(resultClass, returnType, true, false); if (!isValid) { - throw createMethodNotFoundError(method.getName(), parameters, result, ""); + throw createMethodNotFoundError(method.getName(), parameters, resultClass, ""); } }; } @@ -186,9 +188,9 @@ static MethodVerification createParameterVerification(boolean requireAccumulator } else { checkNoState(state); } + checkScalarArgumentsOnly(arguments); final Class[] parameters = assembleParameters(state, arguments); - // TODO enable strict autoboxing - if (!isInvokable(false, method, parameters)) { + if (!isInvokable(true, method, parameters)) { throw createMethodNotFoundError( method.getName(), parameters, @@ -205,11 +207,13 @@ static MethodVerification createParameterVerification(boolean requireAccumulator static MethodVerification createParameterAndCompletableFutureVerification(Class baseClass) { return (method, state, arguments, result) -> { checkNoState(state); - final Class[] parameters = assembleParameters(state, arguments); + checkScalarArgumentsOnly(arguments); + final Class[] parameters = assembleParameters(null, arguments); final Class[] parametersWithFuture = Stream.concat(Stream.of(CompletableFuture.class), Arrays.stream(parameters)) .toArray(Class[]::new); - + assert result != null; + final Class resultClass = result.toClass(); Type genericType = method.getGenericParameterTypes()[0]; genericType = resolveVariableWithClassContext(baseClass, genericType); if (!(genericType instanceof ParameterizedType)) { @@ -218,10 +222,9 @@ static MethodVerification createParameterAndCompletableFutureVerification(Class< method.getName(), 0); } final Type returnType = ((ParameterizedType) genericType).getActualTypeArguments()[0]; - Class returnClazz = getClassFromType(returnType); - // TODO enable strict autoboxing - if (!(isInvokable(false, method, parametersWithFuture) - && isAssignable(result, returnClazz, true, false))) { + Class returnTypeClass = getClassFromType(returnType); + if (!(isInvokable(true, method, parametersWithFuture) + && isAssignable(resultClass, returnTypeClass, true, false))) { throw createMethodNotFoundError( method.getName(), parametersWithFuture, diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java index 67f4e2ddfc095..1494c31e15f56 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/extraction/ProcedureMappingExtractor.java @@ -87,21 +87,25 @@ static ResultExtraction createOutputFromArrayReturnTypeInMethod() { static MethodVerification createParameterWithOptionalContextAndArrayReturnTypeVerification() { return (method, state, arguments, result) -> { checkNoState(state); - final Class[] parameters = assembleParameters(state, arguments); + checkScalarArgumentsOnly(arguments); + final Class[] parameters = assembleParameters(null, arguments); // ignore the ProcedureContext in the first argument final Class[] parametersWithContext = Stream.concat(Stream.of((Class) null), Arrays.stream(parameters)) .toArray(Class[]::new); + assert result != null; + final Class resultClass = result.toClass(); final Class returnType = method.getReturnType(); final boolean isValid = isInvokable(true, method, parametersWithContext) && returnType.isArray() - && isAssignable(result, returnType.getComponentType(), true, false); + && isAssignable( + resultClass, returnType.getComponentType(), true, false); if (!isValid) { throw createMethodNotFoundError( method.getName(), parametersWithContext, - Array.newInstance(result, 0).getClass(), + Array.newInstance(resultClass, 0).getClass(), "( [, ]*)"); } }; diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java index 5774d94fdd3ca..867941da344fb 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/extraction/TypeInferenceExtractorTest.java @@ -305,6 +305,12 @@ private static Stream functionSpecs() { TypeStrategies.explicit( DataTypes.DOUBLE().notNull().bridgedTo(double.class))), // --- + // non-scalar args + TestSpec.forScalarFunction(TableArgScalarFunction.class) + .expectErrorMessage( + "Only scalar arguments are supported at this location. " + + "But argument 't' declared the following traits: [TABLE_AS_ROW]"), + // --- // different accumulator depending on input TestSpec.forAggregateFunction(InputDependentAccumulatorFunction.class) .expectAccumulator( @@ -2401,4 +2407,10 @@ public void eval(int i) {} public void eval(String i) {} } + + private static class TableArgScalarFunction extends ScalarFunction { + public int eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row t) { + return 0; + } + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java index f17259c9c69f7..0a0369bc3dce9 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/operations/SqlNodeToCallOperationTest.java @@ -212,7 +212,7 @@ private static class ProcedureWithNamedArguments implements Procedure { input = {@DataTypeHint("STRING"), @DataTypeHint("STRING")}, output = @DataTypeHint("INT"), argumentNames = {"c", "d"}) - public java.lang.Integer[] call(ProcedureContext context, String arg3, String arg4) { + public int[] call(ProcedureContext context, String arg3, String arg4) { return null; } }