Skip to content

Commit

Permalink
Additional validation added
Browse files Browse the repository at this point in the history
  • Loading branch information
twalthr committed Dec 18, 2024
1 parent a52847b commit ae8a1de
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@
*/
abstract class BaseMappingExtractor {

private static final EnumSet<StaticArgumentTrait> SCALAR_TRAITS =
EnumSet.of(StaticArgumentTrait.SCALAR);

protected final DataTypeFactory typeFactory;

protected final String methodName;
Expand Down Expand Up @@ -188,8 +191,15 @@ static ResultExtraction createStateFromGenericInClassOrParameters(

protected abstract String getHintType();

protected static Class<?>[] assembleParameters(List<Class<?>> state, List<Class<?>> arguments) {
return Stream.concat(state.stream(), arguments.stream()).toArray(Class[]::new);
protected static Class<?>[] assembleParameters(
@Nullable FunctionStateTemplate state, FunctionSignatureTemplate arguments) {
return Stream.concat(
Optional.ofNullable(state)
.map(FunctionStateTemplate::toClassList)
.orElse(List.of())
.stream(),
arguments.toClassList().stream())
.toArray(Class[]::new);
}

protected static ValidationException createMethodNotFoundError(
Expand Down Expand Up @@ -258,19 +268,33 @@ Map<FunctionSignatureTemplate, T> extractResultMappings(
return (Map<FunctionSignatureTemplate, T>) collectedMappings;
}

protected static void checkNoState(@Nullable List<Class<?>> state) {
if (state != null && !state.isEmpty()) {
protected static void checkNoState(@Nullable FunctionStateTemplate state) {
if (state != null) {
throw extractionError("State is not supported for this kind of function.");
}
}

protected static void checkSingleState(@Nullable List<Class<?>> state) {
if (state == null || state.size() != 1) {
protected static void checkSingleState(@Nullable FunctionStateTemplate state) {
if (state == null || state.toClassList().size() != 1) {
throw extractionError(
"Aggregating functions need exactly one state entry for the accumulator.");
}
}

protected static void checkScalarArgumentsOnly(FunctionSignatureTemplate arguments) {
final EnumSet<StaticArgumentTrait>[] argumentTraits = arguments.argumentTraits;
IntStream.range(0, argumentTraits.length)
.forEach(
pos -> {
if (!argumentTraits[pos].equals(SCALAR_TRAITS)) {
throw extractionError(
"Only scalar arguments are supported at this location. "
+ "But argument '%s' declared the following traits: %s",
arguments.argumentNames[pos], argumentTraits[pos]);
}
});
}

// --------------------------------------------------------------------------------------------
// Helper methods
// --------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -472,16 +496,11 @@ private void verifyMappingForMethod(
(signature, result) -> {
if (result instanceof FunctionStateTemplate) {
final FunctionStateTemplate stateTemplate = (FunctionStateTemplate) result;
verification.verify(
method, stateTemplate.toClassList(), signature.toClassList(), null);
verification.verify(method, stateTemplate, signature, null);
} else if (result instanceof FunctionOutputTemplate) {
final FunctionOutputTemplate outputTemplate =
(FunctionOutputTemplate) result;
verification.verify(
method,
List.of(),
signature.toClassList(),
outputTemplate.toClass());
verification.verify(method, null, signature, outputTemplate);
}
});
}
Expand Down Expand Up @@ -668,7 +687,7 @@ private static EnumSet<StaticArgumentTrait>[] extractArgumentTraits(
final ArgumentHint argumentHint =
arg.parameter.getAnnotation(ArgumentHint.class);
if (argumentHint == null) {
return EnumSet.of(StaticArgumentTrait.SCALAR);
return SCALAR_TRAITS;
}
final List<StaticArgumentTrait> traits =
Arrays.stream(argumentHint.value())
Expand Down Expand Up @@ -735,11 +754,11 @@ interface ResultExtraction {
}

/** Verifies the signature of a method. */
protected interface MethodVerification {
interface MethodVerification {
void verify(
Method method,
List<Class<?>> state,
List<Class<?>> arguments,
@Nullable Class<?> result);
@Nullable FunctionStateTemplate state,
FunctionSignatureTemplate arguments,
@Nullable FunctionOutputTemplate result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,16 @@ static ResultExtraction createOutputFromGenericInMethod(
static MethodVerification createParameterAndReturnTypeVerification() {
return (method, state, arguments, result) -> {
checkNoState(state);
final Class<?>[] parameters = assembleParameters(state, arguments);
checkScalarArgumentsOnly(arguments);
final Class<?>[] parameters = assembleParameters(null, arguments);
assert result != null;
final Class<?> resultClass = result.toClass();
final Class<?> returnType = method.getReturnType();
// TODO enable strict autoboxing
final boolean isValid =
isInvokable(false, method, parameters)
&& isAssignable(result, returnType, true, false);
isInvokable(true, method, parameters)
&& isAssignable(resultClass, returnType, true, false);
if (!isValid) {
throw createMethodNotFoundError(method.getName(), parameters, result, "");
throw createMethodNotFoundError(method.getName(), parameters, resultClass, "");
}
};
}
Expand All @@ -186,9 +188,9 @@ static MethodVerification createParameterVerification(boolean requireAccumulator
} else {
checkNoState(state);
}
checkScalarArgumentsOnly(arguments);
final Class<?>[] parameters = assembleParameters(state, arguments);
// TODO enable strict autoboxing
if (!isInvokable(false, method, parameters)) {
if (!isInvokable(true, method, parameters)) {
throw createMethodNotFoundError(
method.getName(),
parameters,
Expand All @@ -205,11 +207,13 @@ static MethodVerification createParameterVerification(boolean requireAccumulator
static MethodVerification createParameterAndCompletableFutureVerification(Class<?> baseClass) {
return (method, state, arguments, result) -> {
checkNoState(state);
final Class<?>[] parameters = assembleParameters(state, arguments);
checkScalarArgumentsOnly(arguments);
final Class<?>[] parameters = assembleParameters(null, arguments);
final Class<?>[] parametersWithFuture =
Stream.concat(Stream.of(CompletableFuture.class), Arrays.stream(parameters))
.toArray(Class<?>[]::new);

assert result != null;
final Class<?> resultClass = result.toClass();
Type genericType = method.getGenericParameterTypes()[0];
genericType = resolveVariableWithClassContext(baseClass, genericType);
if (!(genericType instanceof ParameterizedType)) {
Expand All @@ -218,10 +222,9 @@ static MethodVerification createParameterAndCompletableFutureVerification(Class<
method.getName(), 0);
}
final Type returnType = ((ParameterizedType) genericType).getActualTypeArguments()[0];
Class<?> returnClazz = getClassFromType(returnType);
// TODO enable strict autoboxing
if (!(isInvokable(false, method, parametersWithFuture)
&& isAssignable(result, returnClazz, true, false))) {
Class<?> returnTypeClass = getClassFromType(returnType);
if (!(isInvokable(true, method, parametersWithFuture)
&& isAssignable(resultClass, returnTypeClass, true, false))) {
throw createMethodNotFoundError(
method.getName(),
parametersWithFuture,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,25 @@ static ResultExtraction createOutputFromArrayReturnTypeInMethod() {
static MethodVerification createParameterWithOptionalContextAndArrayReturnTypeVerification() {
return (method, state, arguments, result) -> {
checkNoState(state);
final Class<?>[] parameters = assembleParameters(state, arguments);
checkScalarArgumentsOnly(arguments);
final Class<?>[] parameters = assembleParameters(null, arguments);
// ignore the ProcedureContext in the first argument
final Class<?>[] parametersWithContext =
Stream.concat(Stream.of((Class<?>) null), Arrays.stream(parameters))
.toArray(Class<?>[]::new);
assert result != null;
final Class<?> resultClass = result.toClass();
final Class<?> returnType = method.getReturnType();
final boolean isValid =
isInvokable(true, method, parametersWithContext)
&& returnType.isArray()
&& isAssignable(result, returnType.getComponentType(), true, false);
&& isAssignable(
resultClass, returnType.getComponentType(), true, false);
if (!isValid) {
throw createMethodNotFoundError(
method.getName(),
parametersWithContext,
Array.newInstance(result, 0).getClass(),
Array.newInstance(resultClass, 0).getClass(),
"(<context> [, <argument>]*)");
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ private static Stream<TestSpec> functionSpecs() {
TypeStrategies.explicit(
DataTypes.DOUBLE().notNull().bridgedTo(double.class))),
// ---
// non-scalar args
TestSpec.forScalarFunction(TableArgScalarFunction.class)
.expectErrorMessage(
"Only scalar arguments are supported at this location. "
+ "But argument 't' declared the following traits: [TABLE_AS_ROW]"),
// ---
// different accumulator depending on input
TestSpec.forAggregateFunction(InputDependentAccumulatorFunction.class)
.expectAccumulator(
Expand Down Expand Up @@ -2401,4 +2407,10 @@ public void eval(int i) {}

public void eval(String i) {}
}

private static class TableArgScalarFunction extends ScalarFunction {
public int eval(@ArgumentHint(ArgumentTrait.TABLE_AS_ROW) Row t) {
return 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ private static class ProcedureWithNamedArguments implements Procedure {
input = {@DataTypeHint("STRING"), @DataTypeHint("STRING")},
output = @DataTypeHint("INT"),
argumentNames = {"c", "d"})
public java.lang.Integer[] call(ProcedureContext context, String arg3, String arg4) {
public int[] call(ProcedureContext context, String arg3, String arg4) {
return null;
}
}
Expand Down

0 comments on commit ae8a1de

Please sign in to comment.