diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/AsyncCodeGenerator.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/AsyncCodeGenerator.java
index fad669b7e55ca..04b8e5ad31a3e 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/AsyncCodeGenerator.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/codegen/AsyncCodeGenerator.java
@@ -129,30 +129,24 @@ private static String generateProcessCode(
                 projection.stream()
                         .map(exprGenerator::generateExpression)
                         .collect(Collectors.toList());
-        int syncIndex = 0;
         int index = 0;
-        StringBuilder outputs = new StringBuilder();
-        StringBuilder syncInvocations = new StringBuilder();
+        StringBuilder metadataInvocations = new StringBuilder();
         StringBuilder asyncInvocation = new StringBuilder();
         if (retainHeader) {
-            outputs.append(String.format("%s.setRowKind(rowKind);\n", recordTerm, inputTerm));
+            metadataInvocations.append(
+                    String.format("%s.setRowKind(rowKind);\n", delegatingFutureTerm));
         }
         for (GeneratedExpression fieldExpr : projectionExprs) {
             if (fieldExpr.resultTerm().isEmpty()) {
-                outputs.append(
-                        String.format("%s.setField(%d, resultObject);\n", recordTerm, index));
                 asyncInvocation.append(fieldExpr.code());
+                metadataInvocations.append(
+                        String.format("%s.addAsyncIndex(%d);\n", delegatingFutureTerm, index));
             } else {
-                outputs.append(
+                metadataInvocations.append(fieldExpr.code());
+                metadataInvocations.append(
                         String.format(
-                                "%s.setField(%d, %s.getSynchronousResult(%d));\n",
-                                recordTerm, index, delegatingFutureTerm, syncIndex));
-                syncInvocations.append(fieldExpr.code());
-                syncInvocations.append(
-                        String.format(
-                                "%s.addSynchronousResult(%s);\n",
-                                delegatingFutureTerm, fieldExpr.resultTerm()));
-                syncIndex++;
+                                "%s.addSynchronousResult(%d, %s);\n",
+                                delegatingFutureTerm, index, fieldExpr.resultTerm()));
             }
             index++;
         }
@@ -165,8 +159,7 @@ private static String generateProcessCode(
         values.put("recordTerm", recordTerm);
         values.put("inputTerm", inputTerm);
         values.put("fieldCount", Integer.toString(LogicalTypeChecks.getFieldCount(outRowType)));
-        values.put("outputs", outputs.toString());
-        values.put("syncInvocations", syncInvocations.toString());
+        values.put("metadataInvocations", metadataInvocations.toString());
         values.put("asyncInvocation", asyncInvocation.toString());
         values.put("errorTerm", errorTerm);
 
@@ -175,23 +168,12 @@ private static String generateProcessCode(
                         "\n",
                         new String[] {
                             "final ${delegatingFutureType} ${delegatingFutureTerm} ",
-                            "    = new ${delegatingFutureType}(${collectorTerm});",
+                            "    = new ${delegatingFutureType}(${collectorTerm}, ${fieldCount});",
                            "final org.apache.flink.types.RowKind rowKind = ${inputTerm}.getRowKind();\n",
                             "try {",
-                            "  java.util.function.Function outputFactory = ",
-                            "      new java.util.function.Function() {",
-                            "    @Override",
-                            "    public ${typeTerm} apply(Object resultObject) {",
-                            "      final ${typeTerm} ${recordTerm} = new ${typeTerm}(${fieldCount});",
-                            "      ${outputs}",
-                            "      return ${recordTerm};",
-                            "    }",
-                            "  };",
-                            "",
-                            "  ${delegatingFutureTerm}.setOutputFactory(outputFactory);",
-                            // Ensure that sync invocations come first so that we know that they're
+                            // Ensure that metadata setup comes first so that we know that it's
                             // available when the async callback occurs.
- " ${syncInvocations}", + " ${metadataInvocations}", " ${asyncInvocation}", "", "} catch (Throwable ${errorTerm}) {", diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/AsyncCodeGeneratorTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/AsyncCodeGeneratorTest.java index d5584de18b673..fb26a6f9af55c 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/AsyncCodeGeneratorTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/codegen/AsyncCodeGeneratorTest.java @@ -35,14 +35,17 @@ import org.apache.flink.table.types.logical.IntType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.types.RowKind; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexNode; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -55,11 +58,15 @@ public class AsyncCodeGeneratorTest { private static final RowType INPUT_TYPE = RowType.of(new IntType(), new BigIntType(), new VarCharType()); + private static final RowType INPUT_TYPE2 = + RowType.of(new VarCharType(), new VarCharType(), new VarCharType()); private PlannerMocks plannerMocks; private SqlToRexConverter converter; + private SqlToRexConverter converter2; private RelDataType tableRowType; + private RelDataType tableRowType2; @BeforeEach public void before() { @@ -75,16 +82,35 @@ public void before() { new IntType(), new BigIntType(), new VarCharType()))); + tableRowType2 = + plannerMocks + .getPlannerContext() + .getTypeFactory() + .buildRelNodeRowType( + JavaScalaConversionUtil.toScala(Arrays.asList("f1", "f2", "f3")), + JavaScalaConversionUtil.toScala( + Arrays.asList( + new VarCharType(), + new VarCharType(), + new VarCharType()))); ShortcutUtils.unwrapContext(plannerMocks.getPlanner().createToRelContext().getCluster()); converter = ShortcutUtils.unwrapContext( plannerMocks.getPlanner().createToRelContext().getCluster()) .getRexFactory() .createSqlToRexConverter(tableRowType, null); + converter2 = + ShortcutUtils.unwrapContext( + plannerMocks.getPlanner().createToRelContext().getCluster()) + .getRexFactory() + .createSqlToRexConverter(tableRowType2, null); plannerMocks .getFunctionCatalog() .registerTemporarySystemFunction("myfunc", new AsyncFunc(), false); + plannerMocks + .getFunctionCatalog() + .registerTemporarySystemFunction("myfunc2", new AsyncFuncThreeParams(), false); plannerMocks .getFunctionCatalog() .registerTemporarySystemFunction("myfunc_error", new AsyncFuncError(), false); @@ -111,6 +137,33 @@ public void testTwoReturnTypes_passThroughFirst() throws Exception { .isEqualTo(GenericRowData.of(3L, StringData.fromString("complete foo 4 6"))); } + @Test + public void testTwoReturnTypes_passThroughFirst_stringArgs() throws Exception { + List rowData = + executeMany( + converter2, + INPUT_TYPE2, + Arrays.asList("f1", "myFunc2(f1, f2, f3)"), + RowType.of(new VarCharType(), new VarCharType()), + Arrays.asList( + GenericRowData.of( + StringData.fromString("a1"), + StringData.fromString("b1"), + StringData.fromString("c1")), + GenericRowData.of( + StringData.fromString("a2"), + StringData.fromString("b2"), + 
StringData.fromString("c2")))); + assertThat(rowData.get(0)) + .isEqualTo( + GenericRowData.of( + StringData.fromString("a1"), StringData.fromString("val a1b1c1"))); + assertThat(rowData.get(1)) + .isEqualTo( + GenericRowData.of( + StringData.fromString("a2"), StringData.fromString("val a2b2c2"))); + } + @Test public void testTwoReturnTypes_passThroughSecond() throws Exception { RowData rowData = @@ -124,15 +177,42 @@ public void testTwoReturnTypes_passThroughSecond() throws Exception { @Test public void testError() throws Exception { - CompletableFuture> future = + List>> futures = executeFuture( + converter, + INPUT_TYPE, Arrays.asList("myFunc_error(f1, f2, f3)"), RowType.of(new VarCharType(), new BigIntType()), - GenericRowData.of(2, 3L, StringData.fromString("foo"))); + Arrays.asList(GenericRowData.of(2, 3L, StringData.fromString("foo")))); + CompletableFuture> future = futures.get(0); assertThat(future).isCompletedExceptionally(); assertThatThrownBy(future::get).cause().hasMessage("Error!"); } + @Test + public void testPassThroughChangelogTypes() throws Exception { + RowData rowData = + execute( + Arrays.asList("myFunc(f1)"), + RowType.of(new IntType()), + GenericRowData.ofKind(RowKind.INSERT, 2, 3L, StringData.fromString("foo"))); + assertThat(rowData).isEqualTo(GenericRowData.of(12)); + RowData rowData2 = + execute( + Arrays.asList("myFunc(f1)"), + RowType.of(new IntType()), + GenericRowData.ofKind( + RowKind.UPDATE_AFTER, 2, 3L, StringData.fromString("foo"))); + assertThat(rowData2).isEqualTo(GenericRowData.ofKind(RowKind.UPDATE_AFTER, 12)); + + RowData rowData3 = + execute( + Arrays.asList("myFunc(f1)"), + RowType.of(new IntType()), + GenericRowData.ofKind(RowKind.DELETE, 2, 3L, StringData.fromString("foo"))); + assertThat(rowData3).isEqualTo(GenericRowData.ofKind(RowKind.DELETE, 12)); + } + private RowData execute(String sqlExpression, RowType resultType, RowData input) throws Exception { return execute(Arrays.asList(sqlExpression), resultType, input); @@ -140,13 +220,42 @@ private RowData execute(String sqlExpression, RowType resultType, RowData input) private RowData execute(List sqlExpressions, RowType resultType, RowData input) throws Exception { - Collection result = executeFuture(sqlExpressions, resultType, input).get(); + Collection result = + executeFuture( + converter, + INPUT_TYPE, + sqlExpressions, + resultType, + Collections.singletonList(input)) + .get(0) + .get(); assertThat(result).hasSize(1); return result.iterator().next(); } - private CompletableFuture> executeFuture( - List sqlExpressions, RowType resultType, RowData input) throws Exception { + private List executeMany( + SqlToRexConverter converter, + RowType rowType, + List sqlExpressions, + RowType resultType, + List inputs) + throws Exception { + List>> list = + executeFuture(converter, rowType, sqlExpressions, resultType, inputs); + CompletableFuture.allOf(list.toArray(new CompletableFuture[0])).get(); + List> results = + list.stream().map(CompletableFuture::join).collect(Collectors.toList()); + assertThat(results).hasSize(inputs.size()); + return results.stream().flatMap(Collection::stream).collect(Collectors.toList()); + } + + private List>> executeFuture( + SqlToRexConverter converter, + RowType rowType, + List sqlExpressions, + RowType resultType, + List inputs) + throws Exception { List nodes = sqlExpressions.stream() .map(sql -> converter.convertToRexNode(sql)) @@ -154,7 +263,7 @@ private CompletableFuture> executeFuture( GeneratedFunction> function = AsyncCodeGenerator.generateFunction( "name", - 
+                        rowType,
                         resultType,
                         nodes,
                         true,
@@ -162,9 +271,13 @@ private CompletableFuture<Collection<RowData>> executeFuture(
                         Thread.currentThread().getContextClassLoader());
         AsyncFunction<RowData, RowData> asyncFunction =
                 function.newInstance(Thread.currentThread().getContextClassLoader());
-        TestResultFuture resultFuture = new TestResultFuture();
-        asyncFunction.asyncInvoke(input, resultFuture);
-        return resultFuture.getResult();
+        List<CompletableFuture<Collection<RowData>>> results = new ArrayList<>();
+        for (RowData input : inputs) {
+            TestResultFuture resultFuture = new TestResultFuture();
+            asyncFunction.asyncInvoke(input, resultFuture);
+            results.add(resultFuture.getResult());
+        }
+        return results;
     }
 
     /** Test function. */
@@ -172,6 +285,10 @@ public static final class AsyncFunc extends AsyncScalarFunction {
         public void eval(CompletableFuture<String> f, Integer i, Long l, String s) {
             f.complete("complete " + s + " " + (i * i) + " " + (2 * l));
         }
+
+        public void eval(CompletableFuture<Integer> f, Integer i) {
+            f.complete(i + 10);
+        }
     }
 
     /** Test function. */
@@ -181,6 +298,16 @@ public void eval(CompletableFuture<String> f, Integer i, Long l, String s) {
         }
     }
 
+    /** Test function. */
+    public static class AsyncFuncThreeParams extends AsyncScalarFunction {
+
+        private static final long serialVersionUID = 1L;
+
+        public void eval(CompletableFuture<String> future, String a, String b, String c) {
+            future.complete("val " + a + b + c);
+        }
+    }
+
     /** Test result future. */
     public static final class TestResultFuture implements ResultFuture<RowData> {
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
index b06454f9006d7..ab4e5978442a1 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
@@ -30,6 +30,7 @@
 import org.apache.flink.table.functions.FunctionContext;
 import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
 import org.apache.flink.types.Row;
+import org.apache.flink.types.RowKind;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -46,6 +47,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import static org.apache.flink.table.api.Expressions.row;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** IT Case tests for {@link AsyncScalarFunction}. */
@@ -238,6 +240,35 @@ public void testFailures() {
         assertThat(results).containsSequence(expectedRows);
     }
 
+    @Test
+    public void testMultiArgumentAsyncWithAdditionalProjection() {
+        // This previously triggered a bug where the reference to the sync projection was
+        // getting garbled by Janino.
+        Table t1 =
+                tEnv.fromValues(row("a1", "b1", "c1"), row("a2", "b2", "c2")).as("f1", "f2", "f3");
+        tEnv.createTemporaryView("t1", t1);
+        tEnv.createTemporarySystemFunction("func", new AsyncFuncThreeParams());
+        final List<Row> results = executeSql("select f1, func(f1, f2, f3) FROM t1");
+        final List<Row> expectedRows =
+                Arrays.asList(Row.of("a1", "val a1b1c1"), Row.of("a2", "val a2b2c2"));
+        assertThat(results).containsSequence(expectedRows);
+    }
+
+    @Test
+    public void testGroupBy() {
+        Table t1 = tEnv.fromValues(row(1, 1), row(2, 2), row(1, 3)).as("f1", "f2");
+        tEnv.createTemporaryView("t1", t1);
+        tEnv.createTemporarySystemFunction("func", new AsyncFuncAdd10());
+        final List<Row> results = executeSql("select f1, func(SUM(f2)) FROM t1 group by f1");
+        final List<Row> expectedRows =
+                Arrays.asList(
+                        Row.of(1, 11),
+                        Row.of(2, 12),
+                        Row.ofKind(RowKind.UPDATE_BEFORE, 1, 11),
+                        Row.ofKind(RowKind.UPDATE_AFTER, 1, 14));
+        assertThat(results).containsSequence(expectedRows);
+    }
+
     private List<Row> executeSql(String sql) {
         TableResult result = tEnv.executeSql(sql);
         final List<Row> rows = new ArrayList<>();
@@ -255,6 +286,16 @@ public void eval(CompletableFuture<Integer> future, Integer param) {
         }
     }
 
+    /** Test function. */
+    public static class AsyncFuncThreeParams extends AsyncFuncBase {
+
+        private static final long serialVersionUID = 1L;
+
+        public void eval(CompletableFuture<String> future, String a, String b, String c) {
+            executor.schedule(() -> future.complete("val " + a + b + c), 10, TimeUnit.MILLISECONDS);
+        }
+    }
+
     /** Test function. */
     public static class AsyncFuncAdd10 extends AsyncFuncBase {
 
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/DelegatingAsyncResultFuture.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/DelegatingAsyncResultFuture.java
index dd6ae830ce034..6c071686dde4a 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/DelegatingAsyncResultFuture.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/DelegatingAsyncResultFuture.java
@@ -19,16 +19,17 @@
 package org.apache.flink.table.runtime.operators.calc.async;
 
 import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.conversion.DataStructureConverter;
+import org.apache.flink.types.RowKind;
 import org.apache.flink.util.Preconditions;
 
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.List;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.BiConsumer;
-import java.util.function.Function;
 
 /**
  * Inspired by {@link org.apache.flink.table.runtime.operators.join.lookup.DelegatingResultFuture}
@@ -37,32 +38,38 @@ public class DelegatingAsyncResultFuture implements BiConsumer<Object, Throwable> {
 
     private final ResultFuture<RowData> delegatedResultFuture;
 
-    private final List<Object> synchronousResults = new ArrayList<>();
-    private Function<Object, RowData> outputFactory;
+    private final int totalResultSize;
+    private final Map<Integer, Object> synchronousIndexToResults = new HashMap<>();
 
     private CompletableFuture<Object> future;
     private DataStructureConverter<Object, Object> converter;
 
-    public DelegatingAsyncResultFuture(ResultFuture<RowData> delegatedResultFuture) {
+    private int asyncIndex = -1;
+    private RowKind rowKind;
+
+    public DelegatingAsyncResultFuture(
+            ResultFuture<RowData> delegatedResultFuture, int totalResultSize) {
         this.delegatedResultFuture = delegatedResultFuture;
+        this.totalResultSize = totalResultSize;
     }
 
-    public synchronized void addSynchronousResult(Object object) {
-        synchronousResults.add(object);
+    public synchronized void setRowKind(RowKind rowKind) {
+        this.rowKind = rowKind;
     }
 
-    public synchronized Object getSynchronousResult(int index) {
-        return synchronousResults.get(index);
+    public synchronized void addSynchronousResult(int resultIndex, Object object) {
+        synchronousIndexToResults.put(resultIndex, object);
     }
 
-    public void setOutputFactory(Function<Object, RowData> outputFactory) {
-        this.outputFactory = outputFactory;
+    public synchronized void addAsyncIndex(int resultIndex) {
+        Preconditions.checkState(asyncIndex == -1);
+        asyncIndex = resultIndex;
     }
 
     public CompletableFuture<Object> createAsyncFuture(
             DataStructureConverter<Object, Object> converter) {
         Preconditions.checkState(future == null);
         Preconditions.checkState(this.converter == null);
-        Preconditions.checkNotNull(outputFactory);
+        Preconditions.checkState(this.asyncIndex >= 0);
         future = new CompletableFuture<>();
         this.converter = converter;
         future.whenComplete(this);
@@ -78,11 +85,23 @@ public void accept(Object o, Throwable throwable) {
                 delegatedResultFuture.complete(
                         () -> {
                             Object converted = converter.toInternal(o);
-                            return Collections.singleton(outputFactory.apply(converted));
+                            return Collections.singleton(createResult(converted));
                         });
             } catch (Throwable t) {
                 delegatedResultFuture.completeExceptionally(t);
            }
         }
     }
+
+    private RowData createResult(Object asyncResult) {
+        GenericRowData result = new GenericRowData(totalResultSize);
+        if (rowKind != null) {
+            result.setRowKind(rowKind);
+        }
+        for (Map.Entry<Integer, Object> entry : synchronousIndexToResults.entrySet()) {
+            result.setField(entry.getKey(), entry.getValue());
+        }
+        result.setField(asyncIndex, asyncResult);
+        return result;
+    }
 }
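
For orientation (not part of the patch itself), the process body that AsyncCodeGenerator emits after this change follows roughly the shape sketched below for a projection like "SELECT f1, func(f1, f2, f3)". This is an illustrative sketch only: `delegate`, `resultFuture`, `converter`, `input`, the field terms, and `asyncFunctionCall` are placeholder names standing in for the `${...}` terms the generator substitutes.

    // Sketch of the generated pattern; placeholder names, not literal generated output.
    DelegatingAsyncResultFuture delegate =
            new DelegatingAsyncResultFuture(resultFuture, 2); // 2 = total output field count
    final org.apache.flink.types.RowKind rowKind = input.getRowKind();
    try {
        // Metadata setup runs first so it is already recorded when the async callback fires.
        delegate.setRowKind(rowKind);          // retainHeader: forward the changelog row kind
        delegate.addSynchronousResult(0, f1);  // pass-through column stored under its output index
        delegate.addAsyncIndex(1);             // output index to be filled by the async result
        // The async call completes the future; on completion DelegatingAsyncResultFuture builds a
        // GenericRowData(2), copies the synchronous fields by index, sets the async field, and
        // hands the row to the wrapped ResultFuture.
        java.util.concurrent.CompletableFuture<Object> future = delegate.createAsyncFuture(converter);
        asyncFunctionCall(future, f1, f2, f3);
    } catch (Throwable t) {
        resultFuture.completeExceptionally(t);
    }

Because the pass-through values are registered on the delegate by output index instead of being captured by a generated output-factory closure, the callback no longer references the synchronous projection terms directly, which is what previously got garbled by Janino.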