Skip to content

Commit d002034

Browse files
committed
[FRT-734] Makes Async calcs share correlate split rule with Python (#1382)
* [FRT-734] Makes Async calcs share correlate split rule with Python * Style * Feedback
1 parent f0ed3f7 commit d002034

File tree

11 files changed

+614
-295
lines changed

11 files changed

+614
-295
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java

+10
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ public boolean isRemoteCall(RexNode node) {
9090
public boolean isNonRemoteCall(RexNode node) {
9191
return AsyncUtil.isNonAsyncCall(node);
9292
}
93+
94+
@Override
95+
public String getName() {
96+
return "Async";
97+
}
98+
99+
@Override
100+
public boolean equals(Object obj) {
101+
return obj != null && this.getClass() == obj.getClass();
102+
}
93103
}
94104

95105
private static boolean hasNestedCalls(List<RexNode> projects) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.rules.logical;
20+
21+
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
22+
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
23+
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
24+
import org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRule.AsyncRemoteCalcCallFinder;
25+
26+
import org.apache.calcite.plan.RelOptRule;
27+
28+
/**
29+
* Rule will split the Async {@link FlinkLogicalTableFunctionScan} with Java calls or the Java
30+
* {@link FlinkLogicalTableFunctionScan} with Async calls into a {@link FlinkLogicalCalc} which will
31+
* be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link
32+
* FlinkLogicalTableFunctionScan}.
33+
*/
34+
public class AsyncCorrelateSplitRule {
35+
36+
private static final RemoteCalcCallFinder ASYNC_CALL_FINDER = new AsyncRemoteCalcCallFinder();
37+
38+
public static final RelOptRule CORRELATE_SPLIT =
39+
new RemoteCorrelateSplitRule(ASYNC_CALL_FINDER);
40+
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java

+3-294
Original file line numberDiff line numberDiff line change
@@ -21,306 +21,15 @@
2121
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
2222
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
2323
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
24-
import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
25-
import org.apache.flink.table.planner.plan.utils.PythonUtil;
26-
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
27-
28-
import org.apache.calcite.plan.RelOptRule;
29-
import org.apache.calcite.plan.RelOptRuleCall;
30-
import org.apache.calcite.plan.hep.HepRelVertex;
31-
import org.apache.calcite.rel.RelNode;
32-
import org.apache.calcite.rel.type.RelDataType;
33-
import org.apache.calcite.rel.type.RelDataTypeField;
34-
import org.apache.calcite.rex.RexBuilder;
35-
import org.apache.calcite.rex.RexCall;
36-
import org.apache.calcite.rex.RexCorrelVariable;
37-
import org.apache.calcite.rex.RexFieldAccess;
38-
import org.apache.calcite.rex.RexInputRef;
39-
import org.apache.calcite.rex.RexNode;
40-
import org.apache.calcite.rex.RexProgram;
41-
import org.apache.calcite.rex.RexProgramBuilder;
42-
import org.apache.calcite.rex.RexUtil;
43-
import org.apache.calcite.sql.validate.SqlValidatorUtil;
44-
45-
import java.util.LinkedList;
46-
import java.util.List;
47-
import java.util.stream.Collectors;
48-
49-
import scala.collection.Iterator;
50-
import scala.collection.mutable.ArrayBuffer;
5124

5225
/**
5326
* Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java calls or the Java
5427
* {@link FlinkLogicalTableFunctionScan} with Python calls into a {@link FlinkLogicalCalc} which
5528
* will be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link
5629
* FlinkLogicalTableFunctionScan}.
5730
*/
58-
public class PythonCorrelateSplitRule extends RelOptRule {
59-
public static final PythonCorrelateSplitRule INSTANCE = new PythonCorrelateSplitRule();
60-
61-
private PythonCorrelateSplitRule() {
62-
super(operand(FlinkLogicalCorrelate.class, any()), "PythonCorrelateSplitRule");
63-
}
64-
65-
private FlinkLogicalTableFunctionScan createNewScan(
66-
FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter splitter) {
67-
RexCall rightRexCall = (RexCall) scan.getCall();
68-
// extract Java funcs from Python TableFunction or Python funcs from Java TableFunction.
69-
List<RexNode> rightCalcProjects =
70-
rightRexCall.getOperands().stream()
71-
.map(x -> x.accept(splitter))
72-
.collect(Collectors.toList());
73-
74-
RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(), rightCalcProjects);
75-
return new FlinkLogicalTableFunctionScan(
76-
scan.getCluster(),
77-
scan.getTraitSet(),
78-
scan.getInputs(),
79-
newRightRexCall,
80-
scan.getElementType(),
81-
scan.getRowType(),
82-
scan.getColumnMappings());
83-
}
84-
85-
@Override
86-
public boolean matches(RelOptRuleCall call) {
87-
FlinkLogicalCorrelate correlate = call.rel(0);
88-
RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel();
89-
FlinkLogicalTableFunctionScan tableFunctionScan;
90-
if (right instanceof FlinkLogicalTableFunctionScan) {
91-
tableFunctionScan = (FlinkLogicalTableFunctionScan) right;
92-
} else if (right instanceof FlinkLogicalCalc) {
93-
tableFunctionScan = StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) right);
94-
} else {
95-
return false;
96-
}
97-
RexNode rexNode = tableFunctionScan.getCall();
98-
if (rexNode instanceof RexCall) {
99-
return PythonUtil.isPythonCall(rexNode, null)
100-
&& PythonUtil.containsNonPythonCall(rexNode)
101-
|| PythonUtil.isNonPythonCall(rexNode)
102-
&& PythonUtil.containsPythonCall(rexNode, null)
103-
|| (PythonUtil.isPythonCall(rexNode, null)
104-
&& RexUtil.containsFieldAccess(rexNode));
105-
}
106-
return false;
107-
}
108-
109-
private List<String> createNewFieldNames(
110-
RelDataType rowType,
111-
RexBuilder rexBuilder,
112-
int primitiveFieldCount,
113-
ArrayBuffer<RexNode> extractedRexNodes,
114-
List<RexNode> calcProjects) {
115-
for (int i = 0; i < primitiveFieldCount; i++) {
116-
calcProjects.add(RexInputRef.of(i, rowType));
117-
}
118-
// change RexCorrelVariable to RexInputRef.
119-
RexDefaultVisitor<RexNode> visitor =
120-
new RexDefaultVisitor<RexNode>() {
121-
@Override
122-
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
123-
RexNode expr = fieldAccess.getReferenceExpr();
124-
if (expr instanceof RexCorrelVariable) {
125-
RelDataTypeField field = fieldAccess.getField();
126-
return new RexInputRef(field.getIndex(), field.getType());
127-
} else {
128-
return rexBuilder.makeFieldAccess(
129-
expr.accept(this), fieldAccess.getField().getIndex());
130-
}
131-
}
132-
133-
@Override
134-
public RexNode visitNode(RexNode rexNode) {
135-
return rexNode;
136-
}
137-
};
138-
// add the fields of the extracted rex calls.
139-
Iterator<RexNode> iterator = extractedRexNodes.iterator();
140-
while (iterator.hasNext()) {
141-
RexNode rexNode = iterator.next();
142-
if (rexNode instanceof RexCall) {
143-
RexCall rexCall = (RexCall) rexNode;
144-
List<RexNode> newProjects =
145-
rexCall.getOperands().stream()
146-
.map(x -> x.accept(visitor))
147-
.collect(Collectors.toList());
148-
RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
149-
calcProjects.add(newRexCall);
150-
} else {
151-
calcProjects.add(rexNode);
152-
}
153-
}
154-
155-
List<String> nameList = new LinkedList<>();
156-
for (int i = 0; i < primitiveFieldCount; i++) {
157-
nameList.add(rowType.getFieldNames().get(i));
158-
}
159-
Iterator<Object> indicesIterator = extractedRexNodes.indices().iterator();
160-
while (indicesIterator.hasNext()) {
161-
nameList.add("f" + indicesIterator.next());
162-
}
163-
return SqlValidatorUtil.uniquify(
164-
nameList, rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive());
165-
}
166-
167-
private FlinkLogicalCalc createNewLeftCalc(
168-
RelNode left,
169-
RexBuilder rexBuilder,
170-
ArrayBuffer<RexNode> extractedRexNodes,
171-
FlinkLogicalCorrelate correlate) {
172-
// add the fields of the primitive left input.
173-
List<RexNode> leftCalcProjects = new LinkedList<>();
174-
RelDataType leftRowType = left.getRowType();
175-
List<String> leftCalcCalcFieldNames =
176-
createNewFieldNames(
177-
leftRowType,
178-
rexBuilder,
179-
leftRowType.getFieldCount(),
180-
extractedRexNodes,
181-
leftCalcProjects);
182-
183-
// create a new calc
184-
return new FlinkLogicalCalc(
185-
correlate.getCluster(),
186-
correlate.getTraitSet(),
187-
left,
188-
RexProgram.create(
189-
leftRowType, leftCalcProjects, null, leftCalcCalcFieldNames, rexBuilder));
190-
}
191-
192-
private FlinkLogicalCalc createTopCalc(
193-
int primitiveLeftFieldCount,
194-
RexBuilder rexBuilder,
195-
ArrayBuffer<RexNode> extractedRexNodes,
196-
RelDataType calcRowType,
197-
FlinkLogicalCorrelate newCorrelate) {
198-
RexProgram rexProgram =
199-
new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram();
200-
int offset = extractedRexNodes.size() + primitiveLeftFieldCount;
201-
202-
// extract correlate output RexNode.
203-
List<RexNode> newTopCalcProjects =
204-
rexProgram.getExprList().stream()
205-
.filter(x -> x instanceof RexInputRef)
206-
.filter(
207-
x -> {
208-
int index = ((RexInputRef) x).getIndex();
209-
return index < primitiveLeftFieldCount || index >= offset;
210-
})
211-
.collect(Collectors.toList());
212-
213-
return new FlinkLogicalCalc(
214-
newCorrelate.getCluster(),
215-
newCorrelate.getTraitSet(),
216-
newCorrelate,
217-
RexProgram.create(
218-
newCorrelate.getRowType(),
219-
newTopCalcProjects,
220-
null,
221-
calcRowType,
222-
rexBuilder));
223-
}
224-
225-
private ScalarFunctionSplitter createScalarFunctionSplitter(
226-
RexProgram program,
227-
RexBuilder rexBuilder,
228-
int primitiveLeftFieldCount,
229-
ArrayBuffer<RexNode> extractedRexNodes,
230-
RexNode tableFunctionNode) {
231-
return new ScalarFunctionSplitter(
232-
program,
233-
rexBuilder,
234-
primitiveLeftFieldCount,
235-
extractedRexNodes,
236-
node -> {
237-
if (PythonUtil.isNonPythonCall(tableFunctionNode)) {
238-
// splits the RexCalls which contain Python functions into separate node
239-
return PythonUtil.isPythonCall(node, null);
240-
} else if (PythonUtil.containsNonPythonCall(node)) {
241-
// splits the RexCalls which contain non-Python functions into separate node
242-
return PythonUtil.isNonPythonCall(node);
243-
} else {
244-
// splits the RexFieldAccesses which contain non-Python functions into
245-
// separate node
246-
return node instanceof RexFieldAccess;
247-
}
248-
},
249-
new PythonRemoteCalcCallFinder());
250-
}
251-
252-
@Override
253-
public void onMatch(RelOptRuleCall call) {
254-
FlinkLogicalCorrelate correlate = call.rel(0);
255-
RexBuilder rexBuilder = call.builder().getRexBuilder();
256-
RelNode left = ((HepRelVertex) correlate.getLeft()).getCurrentRel();
257-
RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel();
258-
int primitiveLeftFieldCount = left.getRowType().getFieldCount();
259-
ArrayBuffer<RexNode> extractedRexNodes = new ArrayBuffer<>();
260-
261-
RelNode rightNewInput;
262-
if (right instanceof FlinkLogicalTableFunctionScan) {
263-
FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) right;
264-
rightNewInput =
265-
createNewScan(
266-
scan,
267-
createScalarFunctionSplitter(
268-
null,
269-
rexBuilder,
270-
primitiveLeftFieldCount,
271-
extractedRexNodes,
272-
scan.getCall()));
273-
} else {
274-
FlinkLogicalCalc calc = (FlinkLogicalCalc) right;
275-
FlinkLogicalTableFunctionScan scan = StreamPhysicalCorrelateRule.getTableScan(calc);
276-
FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(calc);
277-
FlinkLogicalTableFunctionScan newScan =
278-
createNewScan(
279-
scan,
280-
createScalarFunctionSplitter(
281-
null,
282-
rexBuilder,
283-
primitiveLeftFieldCount,
284-
extractedRexNodes,
285-
scan.getCall()));
286-
rightNewInput =
287-
mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
288-
}
289-
290-
FlinkLogicalCorrelate newCorrelate;
291-
if (extractedRexNodes.size() > 0) {
292-
FlinkLogicalCalc leftCalc =
293-
createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
294-
295-
newCorrelate =
296-
new FlinkLogicalCorrelate(
297-
correlate.getCluster(),
298-
correlate.getTraitSet(),
299-
leftCalc,
300-
rightNewInput,
301-
correlate.getCorrelationId(),
302-
correlate.getRequiredColumns(),
303-
correlate.getJoinType());
304-
} else {
305-
newCorrelate =
306-
new FlinkLogicalCorrelate(
307-
correlate.getCluster(),
308-
correlate.getTraitSet(),
309-
left,
310-
rightNewInput,
311-
correlate.getCorrelationId(),
312-
correlate.getRequiredColumns(),
313-
correlate.getJoinType());
314-
}
315-
316-
FlinkLogicalCalc newTopCalc =
317-
createTopCalc(
318-
primitiveLeftFieldCount,
319-
rexBuilder,
320-
extractedRexNodes,
321-
correlate.getRowType(),
322-
newCorrelate);
31+
public class PythonCorrelateSplitRule {
32332

324-
call.transformTo(newTopCalc);
325-
}
33+
public static final RemoteCorrelateSplitRule INSTANCE =
34+
new RemoteCorrelateSplitRule(new PythonRemoteCalcCallFinder());
32635
}

0 commit comments

Comments
 (0)