Fixes janino bug returning incorrect results #26504

Open · wants to merge 2 commits into master
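
For context, the query shape this PR targets is a pass-through column projected next to a multi-argument AsyncScalarFunction call (see testMultiArgumentAsyncWithAdditionalProjection below). A minimal, self-contained sketch of that shape, using only the public Table API; the class, table, and function names here are illustrative and not taken from the PR:

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.functions.AsyncScalarFunction;

import java.util.concurrent.CompletableFuture;

import static org.apache.flink.table.api.Expressions.row;

public class AsyncProjectionRepro {

    /** Async UDF with several String arguments, mirroring AsyncFuncThreeParams in the tests below. */
    public static class ConcatAsync extends AsyncScalarFunction {
        public void eval(CompletableFuture<String> future, String a, String b, String c) {
            future.complete("val " + a + b + c);
        }
    }

    public static void main(String[] args) {
        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
        tEnv.createTemporaryView(
                "t1",
                tEnv.fromValues(row("a1", "b1", "c1"), row("a2", "b2", "c2")).as("f1", "f2", "f3"));
        tEnv.createTemporarySystemFunction("func", new ConcatAsync());

        // A pass-through column (f1) projected next to a multi-argument async call is the
        // shape that previously returned incorrect results.
        tEnv.executeSql("SELECT f1, func(f1, f2, f3) FROM t1").print();
    }
}

Before this change, such a query could return incorrect results because janino mis-compiled the generated reference to the synchronous projection; the generator rework below removes the construct that triggered it.
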
@@ -129,30 +129,24 @@ private static String generateProcessCode(
projection.stream()
.map(exprGenerator::generateExpression)
.collect(Collectors.toList());
int syncIndex = 0;
int index = 0;
StringBuilder outputs = new StringBuilder();
StringBuilder syncInvocations = new StringBuilder();
StringBuilder metadataInvocations = new StringBuilder();
StringBuilder asyncInvocation = new StringBuilder();
if (retainHeader) {
outputs.append(String.format("%s.setRowKind(rowKind);\n", recordTerm, inputTerm));
metadataInvocations.append(
String.format("%s.setRowKind(rowKind);\n", delegatingFutureTerm));
}
for (GeneratedExpression fieldExpr : projectionExprs) {
if (fieldExpr.resultTerm().isEmpty()) {
outputs.append(
String.format("%s.setField(%d, resultObject);\n", recordTerm, index));
asyncInvocation.append(fieldExpr.code());
metadataInvocations.append(
String.format("%s.addAsyncIndex(%d);\n", delegatingFutureTerm, index));
} else {
outputs.append(
String.format(
"%s.setField(%d, %s.getSynchronousResult(%d));\n",
recordTerm, index, delegatingFutureTerm, syncIndex));
syncInvocations.append(fieldExpr.code());
syncInvocations.append(
String.format(
"%s.addSynchronousResult(%s);\n",
delegatingFutureTerm, fieldExpr.resultTerm()));
syncIndex++;
metadataInvocations.append(fieldExpr.code());
metadataInvocations.append(
String.format(
"%s.addSynchronousResult(%d, %s);\n",
delegatingFutureTerm, index, fieldExpr.resultTerm()));
}
index++;
}
@@ -165,8 +159,7 @@ private static String generateProcessCode(
values.put("recordTerm", recordTerm);
values.put("inputTerm", inputTerm);
values.put("fieldCount", Integer.toString(LogicalTypeChecks.getFieldCount(outRowType)));
values.put("outputs", outputs.toString());
values.put("syncInvocations", syncInvocations.toString());
values.put("metadataInvocations", metadataInvocations.toString());
values.put("asyncInvocation", asyncInvocation.toString());
values.put("errorTerm", errorTerm);

@@ -175,23 +168,12 @@ private static String generateProcessCode(
"\n",
new String[] {
"final ${delegatingFutureType} ${delegatingFutureTerm} ",
" = new ${delegatingFutureType}(${collectorTerm});",
" = new ${delegatingFutureType}(${collectorTerm}, ${fieldCount});",
"final org.apache.flink.types.RowKind rowKind = ${inputTerm}.getRowKind();\n",
"try {",
" java.util.function.Function<Object, ${typeTerm}> outputFactory = ",
" new java.util.function.Function<Object, ${typeTerm}>() {",
" @Override",
" public ${typeTerm} apply(Object resultObject) {",
" final ${typeTerm} ${recordTerm} = new ${typeTerm}(${fieldCount});",
" ${outputs}",
" return ${recordTerm};",
" }",
" };",
"",
" ${delegatingFutureTerm}.setOutputFactory(outputFactory);",
// Ensure that sync invocations come first so that we know that they're
// Ensure that metadata setup comes first so that we know that it's
// available when the async callback occurs.
" ${syncInvocations}",
" ${metadataInvocations}",
" ${asyncInvocation}",
"",
"} catch (Throwable ${errorTerm}) {",
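
The net effect of the hunks above: the generated code no longer builds rows through a generated anonymous output-factory class (the construct whose captured sync-projection reference janino mis-compiled). Instead, synchronous results and the async slot are registered on the delegating future by output index, and the row is assembled once the async result arrives. The toy model below, written for illustration only, is not the real ${delegatingFutureType} class used by the generated code; only the addSynchronousResult(index, value) / addAsyncIndex(index) calls mirror what the diff emits:

import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

public class DelegatingFutureSketch {

    // Toy stand-in: illustrates the index-based assembly that the generated
    // ${metadataInvocations} block sets up before the async call is issued.
    static final class ToyDelegatingFuture {
        private final Consumer<Object[]> collector;
        private final Object[] row;
        private int asyncIndex = -1;

        ToyDelegatingFuture(Consumer<Object[]> collector, int fieldCount) {
            this.collector = collector;
            this.row = new Object[fieldCount];
        }

        // Pass-through/synchronous projections are stored by output position up front.
        void addSynchronousResult(int index, Object value) {
            row[index] = value;
        }

        // The position the async UDF result will fill later.
        void addAsyncIndex(int index) {
            this.asyncIndex = index;
        }

        // Invoked when the async call completes; only now is the row emitted.
        void completeAsync(Object value) {
            row[asyncIndex] = value;
            collector.accept(row);
        }
    }

    public static void main(String[] args) {
        // Roughly the setup for a projection like "SELECT f1, asyncFunc(f1, f2, f3)".
        ToyDelegatingFuture future =
                new ToyDelegatingFuture(r -> System.out.println(Arrays.toString(r)), 2);
        future.addSynchronousResult(0, "a1"); // sync pass-through column, keyed by index 0
        future.addAsyncIndex(1);              // slot reserved for the async result
        CompletableFuture.completedFuture("val a1b1c1").thenAccept(future::completeAsync);
        // Prints: [a1, val a1b1c1]
    }
}

In the real generated code these registrations form the ${metadataInvocations} block, which is why the template keeps them ahead of ${asyncInvocation}.
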
@@ -35,14 +35,17 @@
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.types.RowKind;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
@@ -55,11 +58,15 @@ public class AsyncCodeGeneratorTest {

private static final RowType INPUT_TYPE =
RowType.of(new IntType(), new BigIntType(), new VarCharType());
private static final RowType INPUT_TYPE2 =
RowType.of(new VarCharType(), new VarCharType(), new VarCharType());

private PlannerMocks plannerMocks;
private SqlToRexConverter converter;
private SqlToRexConverter converter2;

private RelDataType tableRowType;
private RelDataType tableRowType2;

@BeforeEach
public void before() {
@@ -75,16 +82,35 @@ public void before() {
new IntType(),
new BigIntType(),
new VarCharType())));
tableRowType2 =
plannerMocks
.getPlannerContext()
.getTypeFactory()
.buildRelNodeRowType(
JavaScalaConversionUtil.toScala(Arrays.asList("f1", "f2", "f3")),
JavaScalaConversionUtil.toScala(
Arrays.asList(
new VarCharType(),
new VarCharType(),
new VarCharType())));
ShortcutUtils.unwrapContext(plannerMocks.getPlanner().createToRelContext().getCluster());
converter =
ShortcutUtils.unwrapContext(
plannerMocks.getPlanner().createToRelContext().getCluster())
.getRexFactory()
.createSqlToRexConverter(tableRowType, null);
converter2 =
ShortcutUtils.unwrapContext(
plannerMocks.getPlanner().createToRelContext().getCluster())
.getRexFactory()
.createSqlToRexConverter(tableRowType2, null);

plannerMocks
.getFunctionCatalog()
.registerTemporarySystemFunction("myfunc", new AsyncFunc(), false);
plannerMocks
.getFunctionCatalog()
.registerTemporarySystemFunction("myfunc2", new AsyncFuncThreeParams(), false);
plannerMocks
.getFunctionCatalog()
.registerTemporarySystemFunction("myfunc_error", new AsyncFuncError(), false);
@@ -111,6 +137,33 @@ public void testTwoReturnTypes_passThroughFirst() throws Exception {
.isEqualTo(GenericRowData.of(3L, StringData.fromString("complete foo 4 6")));
}

@Test
public void testTwoReturnTypes_passThroughFirst_stringArgs() throws Exception {
List<RowData> rowData =
executeMany(
converter2,
INPUT_TYPE2,
Arrays.asList("f1", "myFunc2(f1, f2, f3)"),
RowType.of(new VarCharType(), new VarCharType()),
Arrays.asList(
GenericRowData.of(
StringData.fromString("a1"),
StringData.fromString("b1"),
StringData.fromString("c1")),
GenericRowData.of(
StringData.fromString("a2"),
StringData.fromString("b2"),
StringData.fromString("c2"))));
assertThat(rowData.get(0))
.isEqualTo(
GenericRowData.of(
StringData.fromString("a1"), StringData.fromString("val a1b1c1")));
assertThat(rowData.get(1))
.isEqualTo(
GenericRowData.of(
StringData.fromString("a2"), StringData.fromString("val a2b2c2")));
}

@Test
public void testTwoReturnTypes_passThroughSecond() throws Exception {
RowData rowData =
@@ -124,54 +177,118 @@ public void testTwoReturnTypes_passThroughSecond() throws Exception {

@Test
public void testError() throws Exception {
CompletableFuture<Collection<RowData>> future =
List<CompletableFuture<Collection<RowData>>> futures =
executeFuture(
converter,
INPUT_TYPE,
Arrays.asList("myFunc_error(f1, f2, f3)"),
RowType.of(new VarCharType(), new BigIntType()),
GenericRowData.of(2, 3L, StringData.fromString("foo")));
Arrays.asList(GenericRowData.of(2, 3L, StringData.fromString("foo"))));
CompletableFuture<Collection<RowData>> future = futures.get(0);
assertThat(future).isCompletedExceptionally();
assertThatThrownBy(future::get).cause().hasMessage("Error!");
}

@Test
public void testPassThroughChangelogTypes() throws Exception {
RowData rowData =
execute(
Arrays.asList("myFunc(f1)"),
RowType.of(new IntType()),
GenericRowData.ofKind(RowKind.INSERT, 2, 3L, StringData.fromString("foo")));
assertThat(rowData).isEqualTo(GenericRowData.of(12));
RowData rowData2 =
execute(
Arrays.asList("myFunc(f1)"),
RowType.of(new IntType()),
GenericRowData.ofKind(
RowKind.UPDATE_AFTER, 2, 3L, StringData.fromString("foo")));
assertThat(rowData2).isEqualTo(GenericRowData.ofKind(RowKind.UPDATE_AFTER, 12));

RowData rowData3 =
execute(
Arrays.asList("myFunc(f1)"),
RowType.of(new IntType()),
GenericRowData.ofKind(RowKind.DELETE, 2, 3L, StringData.fromString("foo")));
assertThat(rowData3).isEqualTo(GenericRowData.ofKind(RowKind.DELETE, 12));
}

private RowData execute(String sqlExpression, RowType resultType, RowData input)
throws Exception {
return execute(Arrays.asList(sqlExpression), resultType, input);
}

private RowData execute(List<String> sqlExpressions, RowType resultType, RowData input)
throws Exception {
Collection<RowData> result = executeFuture(sqlExpressions, resultType, input).get();
Collection<RowData> result =
executeFuture(
converter,
INPUT_TYPE,
sqlExpressions,
resultType,
Collections.singletonList(input))
.get(0)
.get();
assertThat(result).hasSize(1);
return result.iterator().next();
}

private CompletableFuture<Collection<RowData>> executeFuture(
List<String> sqlExpressions, RowType resultType, RowData input) throws Exception {
private List<RowData> executeMany(
SqlToRexConverter converter,
RowType rowType,
List<String> sqlExpressions,
RowType resultType,
List<RowData> inputs)
throws Exception {
List<CompletableFuture<Collection<RowData>>> list =
executeFuture(converter, rowType, sqlExpressions, resultType, inputs);
CompletableFuture.allOf(list.toArray(new CompletableFuture[0])).get();
List<Collection<RowData>> results =
list.stream().map(CompletableFuture::join).collect(Collectors.toList());
assertThat(results).hasSize(inputs.size());
return results.stream().flatMap(Collection::stream).collect(Collectors.toList());
}

private List<CompletableFuture<Collection<RowData>>> executeFuture(
SqlToRexConverter converter,
RowType rowType,
List<String> sqlExpressions,
RowType resultType,
List<RowData> inputs)
throws Exception {
List<RexNode> nodes =
sqlExpressions.stream()
.map(sql -> converter.convertToRexNode(sql))
.collect(Collectors.toList());
GeneratedFunction<AsyncFunction<RowData, RowData>> function =
AsyncCodeGenerator.generateFunction(
"name",
INPUT_TYPE,
rowType,
resultType,
nodes,
true,
new Configuration(),
Thread.currentThread().getContextClassLoader());
AsyncFunction<RowData, RowData> asyncFunction =
function.newInstance(Thread.currentThread().getContextClassLoader());
TestResultFuture resultFuture = new TestResultFuture();
asyncFunction.asyncInvoke(input, resultFuture);
return resultFuture.getResult();
List<CompletableFuture<Collection<RowData>>> results = new ArrayList<>();
for (RowData input : inputs) {
TestResultFuture resultFuture = new TestResultFuture();
asyncFunction.asyncInvoke(input, resultFuture);
results.add(resultFuture.getResult());
}
return results;
}

/** Test function. */
public static final class AsyncFunc extends AsyncScalarFunction {
public void eval(CompletableFuture<String> f, Integer i, Long l, String s) {
f.complete("complete " + s + " " + (i * i) + " " + (2 * l));
}

public void eval(CompletableFuture<Integer> f, Integer i) {
f.complete(i + 10);
}
}

/** Test function. */
@@ -181,6 +298,16 @@ public void eval(CompletableFuture<String> f, Integer i, Long l, String s) {
}
}

/** Test function. */
public static class AsyncFuncThreeParams extends AsyncScalarFunction {

private static final long serialVersionUID = 1L;

public void eval(CompletableFuture<String> future, String a, String b, String c) {
future.complete("val " + a + b + c);
}
}

/** Test result future. */
public static final class TestResultFuture implements ResultFuture<RowData> {

@@ -30,6 +30,7 @@
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -46,6 +47,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.apache.flink.table.api.Expressions.row;
import static org.assertj.core.api.Assertions.assertThat;

/** IT Case tests for {@link AsyncScalarFunction}. */
@@ -238,6 +240,35 @@ public void testFailures() {
assertThat(results).containsSequence(expectedRows);
}

@Test
public void testMultiArgumentAsyncWithAdditionalProjection() {
// This previously triggered a bug where the reference to the sync projection was
// getting garbled by janino.
Table t1 =
tEnv.fromValues(row("a1", "b1", "c1"), row("a2", "b2", "c2")).as("f1", "f2", "f3");
tEnv.createTemporaryView("t1", t1);
tEnv.createTemporarySystemFunction("func", new AsyncFuncThreeParams());
final List<Row> results = executeSql("select f1, func(f1, f2, f3) FROM t1");
final List<Row> expectedRows =
Arrays.asList(Row.of("a1", "val a1b1c1"), Row.of("a2", "val a2b2c2"));
assertThat(results).containsSequence(expectedRows);
}

@Test
public void testGroupBy() {
Table t1 = tEnv.fromValues(row(1, 1), row(2, 2), row(1, 3)).as("f1", "f2");
tEnv.createTemporaryView("t1", t1);
tEnv.createTemporarySystemFunction("func", new AsyncFuncAdd10());
final List<Row> results = executeSql("select f1, func(SUM(f2)) FROM t1 group by f1");
final List<Row> expectedRows =
Arrays.asList(
Row.of(1, 11),
Row.of(2, 12),
Row.ofKind(RowKind.UPDATE_BEFORE, 1, 11),
Row.ofKind(RowKind.UPDATE_AFTER, 1, 14));
assertThat(results).containsSequence(expectedRows);
}

private List<Row> executeSql(String sql) {
TableResult result = tEnv.executeSql(sql);
final List<Row> rows = new ArrayList<>();
@@ -255,6 +286,16 @@ public void eval(CompletableFuture<String> future, Integer param) {
}
}

/** Test function. */
public static class AsyncFuncThreeParams extends AsyncFuncBase {

private static final long serialVersionUID = 1L;

public void eval(CompletableFuture<String> future, String a, String b, String c) {
executor.schedule(() -> future.complete("val " + a + b + c), 10, TimeUnit.MILLISECONDS);
}
}

/** Test function. */
public static class AsyncFuncAdd10 extends AsyncFuncBase {
