|
21 | 21 | import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
|
22 | 22 | import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
|
23 | 23 | import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
|
24 |
| -import org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule; |
25 |
| -import org.apache.flink.table.planner.plan.utils.PythonUtil; |
26 |
| -import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor; |
27 |
| - |
28 |
| -import org.apache.calcite.plan.RelOptRule; |
29 |
| -import org.apache.calcite.plan.RelOptRuleCall; |
30 |
| -import org.apache.calcite.plan.hep.HepRelVertex; |
31 |
| -import org.apache.calcite.rel.RelNode; |
32 |
| -import org.apache.calcite.rel.type.RelDataType; |
33 |
| -import org.apache.calcite.rel.type.RelDataTypeField; |
34 |
| -import org.apache.calcite.rex.RexBuilder; |
35 |
| -import org.apache.calcite.rex.RexCall; |
36 |
| -import org.apache.calcite.rex.RexCorrelVariable; |
37 |
| -import org.apache.calcite.rex.RexFieldAccess; |
38 |
| -import org.apache.calcite.rex.RexInputRef; |
39 |
| -import org.apache.calcite.rex.RexNode; |
40 |
| -import org.apache.calcite.rex.RexProgram; |
41 |
| -import org.apache.calcite.rex.RexProgramBuilder; |
42 |
| -import org.apache.calcite.rex.RexUtil; |
43 |
| -import org.apache.calcite.sql.validate.SqlValidatorUtil; |
44 |
| - |
45 |
| -import java.util.LinkedList; |
46 |
| -import java.util.List; |
47 |
| -import java.util.stream.Collectors; |
48 |
| - |
49 |
| -import scala.collection.Iterator; |
50 |
| -import scala.collection.mutable.ArrayBuffer; |
51 | 24 |
|
52 | 25 | /**
|
53 | 26 | * Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java calls or the Java
|
54 | 27 | * {@link FlinkLogicalTableFunctionScan} with Python calls into a {@link FlinkLogicalCalc} which
|
55 | 28 | * will be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link
|
56 | 29 | * FlinkLogicalTableFunctionScan}.
|
57 | 30 | */
|
58 |
| -public class PythonCorrelateSplitRule extends RelOptRule { |
59 |
| - public static final PythonCorrelateSplitRule INSTANCE = new PythonCorrelateSplitRule(); |
60 |
| - |
61 |
| - private PythonCorrelateSplitRule() { |
62 |
| - super(operand(FlinkLogicalCorrelate.class, any()), "PythonCorrelateSplitRule"); |
63 |
| - } |
64 |
| - |
65 |
| - private FlinkLogicalTableFunctionScan createNewScan( |
66 |
| - FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter splitter) { |
67 |
| - RexCall rightRexCall = (RexCall) scan.getCall(); |
68 |
| - // extract Java funcs from Python TableFunction or Python funcs from Java TableFunction. |
69 |
| - List<RexNode> rightCalcProjects = |
70 |
| - rightRexCall.getOperands().stream() |
71 |
| - .map(x -> x.accept(splitter)) |
72 |
| - .collect(Collectors.toList()); |
73 |
| - |
74 |
| - RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(), rightCalcProjects); |
75 |
| - return new FlinkLogicalTableFunctionScan( |
76 |
| - scan.getCluster(), |
77 |
| - scan.getTraitSet(), |
78 |
| - scan.getInputs(), |
79 |
| - newRightRexCall, |
80 |
| - scan.getElementType(), |
81 |
| - scan.getRowType(), |
82 |
| - scan.getColumnMappings()); |
83 |
| - } |
84 |
| - |
85 |
| - @Override |
86 |
| - public boolean matches(RelOptRuleCall call) { |
87 |
| - FlinkLogicalCorrelate correlate = call.rel(0); |
88 |
| - RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel(); |
89 |
| - FlinkLogicalTableFunctionScan tableFunctionScan; |
90 |
| - if (right instanceof FlinkLogicalTableFunctionScan) { |
91 |
| - tableFunctionScan = (FlinkLogicalTableFunctionScan) right; |
92 |
| - } else if (right instanceof FlinkLogicalCalc) { |
93 |
| - tableFunctionScan = StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) right); |
94 |
| - } else { |
95 |
| - return false; |
96 |
| - } |
97 |
| - RexNode rexNode = tableFunctionScan.getCall(); |
98 |
| - if (rexNode instanceof RexCall) { |
99 |
| - return PythonUtil.isPythonCall(rexNode, null) |
100 |
| - && PythonUtil.containsNonPythonCall(rexNode) |
101 |
| - || PythonUtil.isNonPythonCall(rexNode) |
102 |
| - && PythonUtil.containsPythonCall(rexNode, null) |
103 |
| - || (PythonUtil.isPythonCall(rexNode, null) |
104 |
| - && RexUtil.containsFieldAccess(rexNode)); |
105 |
| - } |
106 |
| - return false; |
107 |
| - } |
108 |
| - |
109 |
| - private List<String> createNewFieldNames( |
110 |
| - RelDataType rowType, |
111 |
| - RexBuilder rexBuilder, |
112 |
| - int primitiveFieldCount, |
113 |
| - ArrayBuffer<RexNode> extractedRexNodes, |
114 |
| - List<RexNode> calcProjects) { |
115 |
| - for (int i = 0; i < primitiveFieldCount; i++) { |
116 |
| - calcProjects.add(RexInputRef.of(i, rowType)); |
117 |
| - } |
118 |
| - // change RexCorrelVariable to RexInputRef. |
119 |
| - RexDefaultVisitor<RexNode> visitor = |
120 |
| - new RexDefaultVisitor<RexNode>() { |
121 |
| - @Override |
122 |
| - public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { |
123 |
| - RexNode expr = fieldAccess.getReferenceExpr(); |
124 |
| - if (expr instanceof RexCorrelVariable) { |
125 |
| - RelDataTypeField field = fieldAccess.getField(); |
126 |
| - return new RexInputRef(field.getIndex(), field.getType()); |
127 |
| - } else { |
128 |
| - return rexBuilder.makeFieldAccess( |
129 |
| - expr.accept(this), fieldAccess.getField().getIndex()); |
130 |
| - } |
131 |
| - } |
132 |
| - |
133 |
| - @Override |
134 |
| - public RexNode visitNode(RexNode rexNode) { |
135 |
| - return rexNode; |
136 |
| - } |
137 |
| - }; |
138 |
| - // add the fields of the extracted rex calls. |
139 |
| - Iterator<RexNode> iterator = extractedRexNodes.iterator(); |
140 |
| - while (iterator.hasNext()) { |
141 |
| - RexNode rexNode = iterator.next(); |
142 |
| - if (rexNode instanceof RexCall) { |
143 |
| - RexCall rexCall = (RexCall) rexNode; |
144 |
| - List<RexNode> newProjects = |
145 |
| - rexCall.getOperands().stream() |
146 |
| - .map(x -> x.accept(visitor)) |
147 |
| - .collect(Collectors.toList()); |
148 |
| - RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects); |
149 |
| - calcProjects.add(newRexCall); |
150 |
| - } else { |
151 |
| - calcProjects.add(rexNode); |
152 |
| - } |
153 |
| - } |
154 |
| - |
155 |
| - List<String> nameList = new LinkedList<>(); |
156 |
| - for (int i = 0; i < primitiveFieldCount; i++) { |
157 |
| - nameList.add(rowType.getFieldNames().get(i)); |
158 |
| - } |
159 |
| - Iterator<Object> indicesIterator = extractedRexNodes.indices().iterator(); |
160 |
| - while (indicesIterator.hasNext()) { |
161 |
| - nameList.add("f" + indicesIterator.next()); |
162 |
| - } |
163 |
| - return SqlValidatorUtil.uniquify( |
164 |
| - nameList, rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive()); |
165 |
| - } |
166 |
| - |
167 |
| - private FlinkLogicalCalc createNewLeftCalc( |
168 |
| - RelNode left, |
169 |
| - RexBuilder rexBuilder, |
170 |
| - ArrayBuffer<RexNode> extractedRexNodes, |
171 |
| - FlinkLogicalCorrelate correlate) { |
172 |
| - // add the fields of the primitive left input. |
173 |
| - List<RexNode> leftCalcProjects = new LinkedList<>(); |
174 |
| - RelDataType leftRowType = left.getRowType(); |
175 |
| - List<String> leftCalcCalcFieldNames = |
176 |
| - createNewFieldNames( |
177 |
| - leftRowType, |
178 |
| - rexBuilder, |
179 |
| - leftRowType.getFieldCount(), |
180 |
| - extractedRexNodes, |
181 |
| - leftCalcProjects); |
182 |
| - |
183 |
| - // create a new calc |
184 |
| - return new FlinkLogicalCalc( |
185 |
| - correlate.getCluster(), |
186 |
| - correlate.getTraitSet(), |
187 |
| - left, |
188 |
| - RexProgram.create( |
189 |
| - leftRowType, leftCalcProjects, null, leftCalcCalcFieldNames, rexBuilder)); |
190 |
| - } |
191 |
| - |
192 |
| - private FlinkLogicalCalc createTopCalc( |
193 |
| - int primitiveLeftFieldCount, |
194 |
| - RexBuilder rexBuilder, |
195 |
| - ArrayBuffer<RexNode> extractedRexNodes, |
196 |
| - RelDataType calcRowType, |
197 |
| - FlinkLogicalCorrelate newCorrelate) { |
198 |
| - RexProgram rexProgram = |
199 |
| - new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram(); |
200 |
| - int offset = extractedRexNodes.size() + primitiveLeftFieldCount; |
201 |
| - |
202 |
| - // extract correlate output RexNode. |
203 |
| - List<RexNode> newTopCalcProjects = |
204 |
| - rexProgram.getExprList().stream() |
205 |
| - .filter(x -> x instanceof RexInputRef) |
206 |
| - .filter( |
207 |
| - x -> { |
208 |
| - int index = ((RexInputRef) x).getIndex(); |
209 |
| - return index < primitiveLeftFieldCount || index >= offset; |
210 |
| - }) |
211 |
| - .collect(Collectors.toList()); |
212 |
| - |
213 |
| - return new FlinkLogicalCalc( |
214 |
| - newCorrelate.getCluster(), |
215 |
| - newCorrelate.getTraitSet(), |
216 |
| - newCorrelate, |
217 |
| - RexProgram.create( |
218 |
| - newCorrelate.getRowType(), |
219 |
| - newTopCalcProjects, |
220 |
| - null, |
221 |
| - calcRowType, |
222 |
| - rexBuilder)); |
223 |
| - } |
224 |
| - |
225 |
| - private ScalarFunctionSplitter createScalarFunctionSplitter( |
226 |
| - RexProgram program, |
227 |
| - RexBuilder rexBuilder, |
228 |
| - int primitiveLeftFieldCount, |
229 |
| - ArrayBuffer<RexNode> extractedRexNodes, |
230 |
| - RexNode tableFunctionNode) { |
231 |
| - return new ScalarFunctionSplitter( |
232 |
| - program, |
233 |
| - rexBuilder, |
234 |
| - primitiveLeftFieldCount, |
235 |
| - extractedRexNodes, |
236 |
| - node -> { |
237 |
| - if (PythonUtil.isNonPythonCall(tableFunctionNode)) { |
238 |
| - // splits the RexCalls which contain Python functions into separate node |
239 |
| - return PythonUtil.isPythonCall(node, null); |
240 |
| - } else if (PythonUtil.containsNonPythonCall(node)) { |
241 |
| - // splits the RexCalls which contain non-Python functions into separate node |
242 |
| - return PythonUtil.isNonPythonCall(node); |
243 |
| - } else { |
244 |
| - // splits the RexFieldAccesses which contain non-Python functions into |
245 |
| - // separate node |
246 |
| - return node instanceof RexFieldAccess; |
247 |
| - } |
248 |
| - }, |
249 |
| - new PythonRemoteCalcCallFinder()); |
250 |
| - } |
251 |
| - |
252 |
| - @Override |
253 |
| - public void onMatch(RelOptRuleCall call) { |
254 |
| - FlinkLogicalCorrelate correlate = call.rel(0); |
255 |
| - RexBuilder rexBuilder = call.builder().getRexBuilder(); |
256 |
| - RelNode left = ((HepRelVertex) correlate.getLeft()).getCurrentRel(); |
257 |
| - RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel(); |
258 |
| - int primitiveLeftFieldCount = left.getRowType().getFieldCount(); |
259 |
| - ArrayBuffer<RexNode> extractedRexNodes = new ArrayBuffer<>(); |
260 |
| - |
261 |
| - RelNode rightNewInput; |
262 |
| - if (right instanceof FlinkLogicalTableFunctionScan) { |
263 |
| - FlinkLogicalTableFunctionScan scan = (FlinkLogicalTableFunctionScan) right; |
264 |
| - rightNewInput = |
265 |
| - createNewScan( |
266 |
| - scan, |
267 |
| - createScalarFunctionSplitter( |
268 |
| - null, |
269 |
| - rexBuilder, |
270 |
| - primitiveLeftFieldCount, |
271 |
| - extractedRexNodes, |
272 |
| - scan.getCall())); |
273 |
| - } else { |
274 |
| - FlinkLogicalCalc calc = (FlinkLogicalCalc) right; |
275 |
| - FlinkLogicalTableFunctionScan scan = StreamPhysicalCorrelateRule.getTableScan(calc); |
276 |
| - FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(calc); |
277 |
| - FlinkLogicalTableFunctionScan newScan = |
278 |
| - createNewScan( |
279 |
| - scan, |
280 |
| - createScalarFunctionSplitter( |
281 |
| - null, |
282 |
| - rexBuilder, |
283 |
| - primitiveLeftFieldCount, |
284 |
| - extractedRexNodes, |
285 |
| - scan.getCall())); |
286 |
| - rightNewInput = |
287 |
| - mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram()); |
288 |
| - } |
289 |
| - |
290 |
| - FlinkLogicalCorrelate newCorrelate; |
291 |
| - if (extractedRexNodes.size() > 0) { |
292 |
| - FlinkLogicalCalc leftCalc = |
293 |
| - createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate); |
294 |
| - |
295 |
| - newCorrelate = |
296 |
| - new FlinkLogicalCorrelate( |
297 |
| - correlate.getCluster(), |
298 |
| - correlate.getTraitSet(), |
299 |
| - leftCalc, |
300 |
| - rightNewInput, |
301 |
| - correlate.getCorrelationId(), |
302 |
| - correlate.getRequiredColumns(), |
303 |
| - correlate.getJoinType()); |
304 |
| - } else { |
305 |
| - newCorrelate = |
306 |
| - new FlinkLogicalCorrelate( |
307 |
| - correlate.getCluster(), |
308 |
| - correlate.getTraitSet(), |
309 |
| - left, |
310 |
| - rightNewInput, |
311 |
| - correlate.getCorrelationId(), |
312 |
| - correlate.getRequiredColumns(), |
313 |
| - correlate.getJoinType()); |
314 |
| - } |
315 |
| - |
316 |
| - FlinkLogicalCalc newTopCalc = |
317 |
| - createTopCalc( |
318 |
| - primitiveLeftFieldCount, |
319 |
| - rexBuilder, |
320 |
| - extractedRexNodes, |
321 |
| - correlate.getRowType(), |
322 |
| - newCorrelate); |
| 31 | +public class PythonCorrelateSplitRule { |
323 | 32 |
|
324 |
| - call.transformTo(newTopCalc); |
325 |
| - } |
| 33 | + public static final RemoteCorrelateSplitRule INSTANCE = |
| 34 | + new RemoteCorrelateSplitRule(new PythonRemoteCalcCallFinder()); |
326 | 35 | }
|
0 commit comments