Skip to content

Commit

Permalink
[Enhancement] Support Generated Column rewrite in complex Query (#50398)
Browse files Browse the repository at this point in the history
Signed-off-by: srlch <[email protected]>
  • Loading branch information
srlch committed Sep 9, 2024
1 parent 90ce1c2 commit 96b97a7
Show file tree
Hide file tree
Showing 13 changed files with 469 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.LimitElement;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.SlotRef;
import com.starrocks.common.IdGenerator;
import com.starrocks.sql.ast.Relation;
import com.starrocks.sql.ast.SelectRelation;
Expand Down Expand Up @@ -56,8 +55,6 @@ public class AnalyzeState {
private Scope orderScope;
private List<Expr> orderSourceExpressions;

private Map<Expr, SlotRef> generatedExprToColumnRef = new HashMap<>();

/**
* outputExprInOrderByScope is used to record which expressions in outputExpression are to be
* recorded in the first level of OrderByScope (order by expressions can refer to columns in output)
Expand Down Expand Up @@ -257,12 +254,4 @@ public ExprId getNextNondeterministicId() {
public List<Expr> getColumnNotInGroupBy() {
return columnNotInGroupBy;
}

public void setGeneratedExprToColumnRef(Map<Expr, SlotRef> generatedExprToColumnRef) {
this.generatedExprToColumnRef = generatedExprToColumnRef;
}

public Map<Expr, SlotRef> getGeneratedExprToColumnRef() {
return generatedExprToColumnRef;
}
}
281 changes: 168 additions & 113 deletions fe/fe-core/src/main/java/com/starrocks/sql/analyzer/QueryAnalyzer.java

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/sql/ast/Relation.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@

package com.starrocks.sql.ast;

import com.starrocks.analysis.Expr;
import com.starrocks.analysis.ParseNode;
import com.starrocks.analysis.SlotRef;
import com.starrocks.analysis.TableName;
import com.starrocks.sql.analyzer.RelationFields;
import com.starrocks.sql.analyzer.Scope;
import com.starrocks.sql.common.ErrorType;
import com.starrocks.sql.common.StarRocksPlannerException;
import com.starrocks.sql.parser.NodePosition;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public abstract class Relation implements ParseNode {
private Scope scope;
Expand All @@ -39,6 +43,12 @@ public abstract class Relation implements ParseNode {
// generated by Security Policy rewriting does not perform permission verification.
private boolean createByPolicyRewritten = false;

/**
* generatedExprToColumnRef stores the mapping relationship
* between generated expressions and generated columns
*/
private Map<Expr, SlotRef> generatedExprToColumnRef = new HashMap<>();

protected final NodePosition pos;

protected Relation(NodePosition pos) {
Expand Down Expand Up @@ -109,6 +119,14 @@ public List<String> getExplicitColumnNames() {
return explicitColumnNames;
}

public void setGeneratedExprToColumnRef(Map<Expr, SlotRef> generatedExprToColumnRef) {
this.generatedExprToColumnRef = generatedExprToColumnRef;
}

public Map<Expr, SlotRef> getGeneratedExprToColumnRef() {
return generatedExprToColumnRef;
}

@Override
public <R, C> R accept(AstVisitor<R, C> visitor, C context) {
return visitor.visitRelation(this, context);
Expand Down
14 changes: 0 additions & 14 deletions fe/fe-core/src/main/java/com/starrocks/sql/ast/SelectRelation.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import com.starrocks.analysis.GroupByClause;
import com.starrocks.analysis.LimitElement;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.SlotRef;
import com.starrocks.sql.analyzer.AnalyzeState;
import com.starrocks.sql.analyzer.FieldId;
import com.starrocks.sql.analyzer.Scope;
import com.starrocks.sql.parser.NodePosition;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -79,12 +77,6 @@ public class SelectRelation extends QueryRelation {

private Map<Expr, FieldId> columnReferences;

/**
* materializeExpressionToColumnRef stores the mapping relationship
* between generated expressions and generated columns
*/
private Map<Expr, SlotRef> generatedExprToColumnRef = new HashMap<>();

public SelectRelation(
SelectList selectList,
Relation fromRelation,
Expand Down Expand Up @@ -160,8 +152,6 @@ public void fillResolvedAST(AnalyzeState analyzeState) {

this.columnReferences = analyzeState.getColumnReferences();

this.generatedExprToColumnRef = analyzeState.getGeneratedExprToColumnRef();

this.setScope(analyzeState.getOutputScope());
}

Expand Down Expand Up @@ -307,8 +297,4 @@ public boolean hasAnalyticInfo() {
public List<Expr> getOutputExpression() {
return outputExpr;
}

public Map<Expr, SlotRef> getGeneratedExprToColumnRef() {
return generatedExprToColumnRef;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.starrocks.common.Config;
import com.starrocks.sql.common.ErrorType;
import com.starrocks.sql.common.StarRocksPlannerException;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.scalar.ArithmeticCommutativeRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ConsolidateLikesRule;
Expand All @@ -28,12 +29,14 @@
import com.starrocks.sql.optimizer.rewrite.scalar.NormalizePredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.PruneTediousPredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReduceCastRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReplaceScalarOperatorRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ScalarOperatorRewriteRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedCaseWhenRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedScanColumnRule;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ScalarOperatorRewriter {
Expand Down Expand Up @@ -157,4 +160,10 @@ public static ScalarOperator simplifyCaseWhen(ScalarOperator predicates) {
// simplify case-when
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_RULE);
}

public static ScalarOperator replaceScalarOperatorByColumnRef(ScalarOperator operator,
Map<ScalarOperator, ColumnRefOperator> translateMap) {
ReplaceScalarOperatorRule rule = new ReplaceScalarOperatorRule(translateMap);
return new ScalarOperatorRewriter().rewrite(operator, Lists.newArrayList(rule));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.optimizer.rewrite.scalar;

import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriteContext;

import java.util.Map;

public class ReplaceScalarOperatorRule extends BottomUpScalarOperatorRewriteRule {
private Map<ScalarOperator, ColumnRefOperator> translateMap;

public ReplaceScalarOperatorRule(Map<ScalarOperator, ColumnRefOperator> translateMap) {
this.translateMap = translateMap;
}

@Override
public ScalarOperator visit(ScalarOperator scalarOperator, ScalarOperatorRewriteContext context) {
for (Map.Entry<ScalarOperator, ColumnRefOperator> m : translateMap.entrySet()) {
if (ScalarOperator.isEquivalent(m.getKey(), scalarOperator)) {
return m.getValue();
}
}
return scalarOperator;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class ExpressionMapping {
// record columnRefOp which is generated by const expr in project node
// if this columnRefOp is referenced by upper node, we should replace it with const expr in upper node
private Map<ColumnRefOperator, ScalarOperator> columnRefToConstOperators = new HashMap<>();
private Map<ScalarOperator, ColumnRefOperator> generatedColumnExprOpToColumnRef = new HashMap<>();

public ExpressionMapping(Scope scope, List<ColumnRefOperator> fieldMappings) {
this.scope = scope;
Expand Down Expand Up @@ -202,4 +203,12 @@ public Map<ColumnRefOperator, ScalarOperator> getColumnRefToConstOperators() {
public void addColumnRefToConstOperators(Map<ColumnRefOperator, ScalarOperator> columnRefToConstOperators) {
this.columnRefToConstOperators.putAll(columnRefToConstOperators);
}

public Map<ScalarOperator, ColumnRefOperator> getGeneratedColumnExprOpToColumnRef() {
return generatedColumnExprOpToColumnRef;
}

public void addGeneratedColumnExprOpToColumnRef(Map<ScalarOperator, ColumnRefOperator> generatedColumnExprOpToColumnRef) {
this.generatedColumnExprOpToColumnRef.putAll(generatedColumnExprOpToColumnRef);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ public Map<ColumnRefOperator, ScalarOperator> getColumnRefToConstOperators() {
return expressionMapping.getColumnRefToConstOperators();
}

public Map<ScalarOperator, ColumnRefOperator> getGeneratedColumnExprOpToColumnRef() {
return expressionMapping.getGeneratedColumnExprOpToColumnRef();
}

public void setExpressionMapping(ExpressionMapping expressionMapping) {
this.expressionMapping = expressionMapping;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@ public LogicalPlan plan(SelectRelation queryBlock, ExpressionMapping outer) {
builder.getColumnRefToConstOperators()));

Map<Expr, SlotRef> generatedExprToColumnRef = queryBlock.getGeneratedExprToColumnRef();
ExpressionMapping expressionMapping = builder.getExpressionMapping();
Map<ScalarOperator, ColumnRefOperator> generatedColumnExprOpToColumnRef = new HashMap<>();
for (Map.Entry<Expr, SlotRef> m : generatedExprToColumnRef.entrySet()) {
ScalarOperator scalarOperator = SqlToScalarOperatorTranslator.translate(m.getValue(),
ScalarOperator scalarOperator = SqlToScalarOperatorTranslator.translate(m.getKey(),
builder.getExpressionMapping(), columnRefFactory);
expressionMapping.put(m.getKey(), (ColumnRefOperator) scalarOperator);
ColumnRefOperator columnRefOp = (ColumnRefOperator) SqlToScalarOperatorTranslator.translate(m.getValue(),
builder.getExpressionMapping(), columnRefFactory);
generatedColumnExprOpToColumnRef.put(scalarOperator, columnRefOp);
}
builder.getExpressionMapping().addGeneratedColumnExprOpToColumnRef(generatedColumnExprOpToColumnRef);

builder = filter(builder, queryBlock.getPredicate());
builder = aggregate(builder, queryBlock.getGroupBy(), queryBlock.getAggregate(),
Expand Down Expand Up @@ -246,6 +249,7 @@ private OptExprBuilder projectForOrder(OptExprBuilder subOpt,

outputTranslations.addExpressionToColumns(subOpt.getExpressionMapping().getExpressionToColumns());
outputTranslations.addColumnRefToConstOperators(subOpt.getColumnRefToConstOperators());
outputTranslations.addGeneratedColumnExprOpToColumnRef(subOpt.getGeneratedColumnExprOpToColumnRef());

LogicalProjectOperator projectOperator = new LogicalProjectOperator(projections);
return new OptExprBuilder(projectOperator, Lists.newArrayList(subOpt), outputTranslations);
Expand Down Expand Up @@ -279,6 +283,7 @@ private OptExprBuilder project(OptExprBuilder subOpt, Iterable<Expr> expressions

outputTranslations.addExpressionToColumns(subOpt.getExpressionMapping().getExpressionToColumns());
outputTranslations.addColumnRefToConstOperators(subOpt.getColumnRefToConstOperators());
outputTranslations.addGeneratedColumnExprOpToColumnRef(subOpt.getGeneratedColumnExprOpToColumnRef());

LogicalProjectOperator projectOperator = new LogicalProjectOperator(projections, limit);
return new OptExprBuilder(projectOperator, Lists.newArrayList(subOpt), outputTranslations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ public static ScalarOperator translate(Expr expression, ExpressionMapping expres
ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter();
result = scalarRewriter.rewrite(result, ScalarOperatorRewriter.DEFAULT_REWRITE_RULES);

result = ScalarOperatorRewriter.replaceScalarOperatorByColumnRef(result,
expressionMapping.getGeneratedColumnExprOpToColumnRef());

requireNonNull(result, "translated expression is null");
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ public void test() throws Exception {

sql = " select tmc.v1 + 1 from tmc as v,tmc2 as tmc";
plan = getFragmentPlan(sql);
assertContains(plan, "<slot 8> : 5: v1 + 1");
assertContains(plan, "<slot 3> : 3: v3");

sql = " select tmc.v1 + 1 from tmc as v,tmc2 as tmc";
plan = getFragmentPlan(sql);
assertContains(plan, "<slot 8> : 5: v1 + 1");
assertContains(plan, "<slot 3> : 3: v3");

sql = " select * from view_1";
plan = getFragmentPlan(sql);
Expand Down
Loading

0 comments on commit 96b97a7

Please sign in to comment.