Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Support Generated Column rewrite in complex Query #50398

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.starrocks.analysis.FunctionCallExpr;
import com.starrocks.analysis.LimitElement;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.SlotRef;
import com.starrocks.common.IdGenerator;
import com.starrocks.sql.ast.Relation;
import com.starrocks.sql.ast.SelectRelation;
Expand Down Expand Up @@ -56,8 +55,6 @@ public class AnalyzeState {
private Scope orderScope;
private List<Expr> orderSourceExpressions;

private Map<Expr, SlotRef> generatedExprToColumnRef = new HashMap<>();

/**
* outputExprInOrderByScope is used to record which expressions in outputExpression are to be
* recorded in the first level of OrderByScope (order by expressions can refer to columns in output)
Expand Down Expand Up @@ -257,12 +254,4 @@ public ExprId getNextNondeterministicId() {
/**
 * Returns the {@code columnNotInGroupBy} list held by this object.
 *
 * @return the stored list itself, not a copy
 */
public List<Expr> getColumnNotInGroupBy() {
    return this.columnNotInGroupBy;
}

/**
 * Replaces the stored generated-expression to column-reference mapping.
 *
 * @param generatedExprToColumnRef the new mapping; the reference is stored as-is, without copying
 */
public void setGeneratedExprToColumnRef(Map<Expr, SlotRef> generatedExprToColumnRef) {
this.generatedExprToColumnRef = generatedExprToColumnRef;
}

/**
 * Returns the stored generated-expression to column-reference mapping.
 *
 * @return the stored map itself, not a copy
 */
public Map<Expr, SlotRef> getGeneratedExprToColumnRef() {
    return this.generatedExprToColumnRef;
}
}
281 changes: 168 additions & 113 deletions fe/fe-core/src/main/java/com/starrocks/sql/analyzer/QueryAnalyzer.java

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/sql/ast/Relation.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@

package com.starrocks.sql.ast;

import com.starrocks.analysis.Expr;
import com.starrocks.analysis.ParseNode;
import com.starrocks.analysis.SlotRef;
import com.starrocks.analysis.TableName;
import com.starrocks.sql.analyzer.RelationFields;
import com.starrocks.sql.analyzer.Scope;
import com.starrocks.sql.common.ErrorType;
import com.starrocks.sql.common.StarRocksPlannerException;
import com.starrocks.sql.parser.NodePosition;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public abstract class Relation implements ParseNode {
private Scope scope;
Expand All @@ -39,6 +43,12 @@ public abstract class Relation implements ParseNode {
// generated by Security Policy rewriting does not perform permission verification.
private boolean createByPolicyRewritten = false;

/**
* generatedExprToColumnRef stores the mapping relationship
* between generated expressions and generated columns
*/
private Map<Expr, SlotRef> generatedExprToColumnRef = new HashMap<>();

protected final NodePosition pos;

protected Relation(NodePosition pos) {
Expand Down Expand Up @@ -109,6 +119,14 @@ public List<String> getExplicitColumnNames() {
return explicitColumnNames;
}

/**
 * Replaces the mapping between generated expressions and generated columns kept on this relation.
 *
 * @param generatedExprToColumnRef the new mapping; the reference is stored as-is, without copying
 */
public void setGeneratedExprToColumnRef(Map<Expr, SlotRef> generatedExprToColumnRef) {
this.generatedExprToColumnRef = generatedExprToColumnRef;
}

/**
 * Returns the mapping between generated expressions and generated columns kept on this relation.
 *
 * @return the stored map itself, not a copy
 */
public Map<Expr, SlotRef> getGeneratedExprToColumnRef() {
    return this.generatedExprToColumnRef;
}

@Override
public <R, C> R accept(AstVisitor<R, C> visitor, C context) {
return visitor.visitRelation(this, context);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The most risky bug in this code is:
Potential thread-safety issues with generatedExprToColumnRef.

You can modify the code like this:

import java.util.Collections;
import java.util.concurrent.ConcurrentHashMap;
//...

private Map<Expr, SlotRef> generatedExprToColumnRef = new ConcurrentHashMap<>();

public void setGeneratedExprToColumnRef(Map<Expr, SlotRef> generatedExprToColumnRef) {
    this.generatedExprToColumnRef = Collections.synchronizedMap(generatedExprToColumnRef);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import com.starrocks.analysis.GroupByClause;
import com.starrocks.analysis.LimitElement;
import com.starrocks.analysis.OrderByElement;
import com.starrocks.analysis.SlotRef;
import com.starrocks.sql.analyzer.AnalyzeState;
import com.starrocks.sql.analyzer.FieldId;
import com.starrocks.sql.analyzer.Scope;
import com.starrocks.sql.parser.NodePosition;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -79,12 +77,6 @@ public class SelectRelation extends QueryRelation {

private Map<Expr, FieldId> columnReferences;

/**
* generatedExprToColumnRef stores the mapping relationship
* between generated expressions and generated columns
*/
private Map<Expr, SlotRef> generatedExprToColumnRef = new HashMap<>();

public SelectRelation(
SelectList selectList,
Relation fromRelation,
Expand Down Expand Up @@ -160,8 +152,6 @@ public void fillResolvedAST(AnalyzeState analyzeState) {

this.columnReferences = analyzeState.getColumnReferences();

this.generatedExprToColumnRef = analyzeState.getGeneratedExprToColumnRef();

this.setScope(analyzeState.getOutputScope());
}

Expand Down Expand Up @@ -307,8 +297,4 @@ public boolean hasAnalyticInfo() {
/**
 * Returns the {@code outputExpr} list held by this relation.
 *
 * @return the stored list itself, not a copy
 */
public List<Expr> getOutputExpression() {
    return this.outputExpr;
}

/**
 * Returns the mapping between generated expressions and generated columns.
 *
 * @return the stored map itself, not a copy
 */
public Map<Expr, SlotRef> getGeneratedExprToColumnRef() {
    return this.generatedExprToColumnRef;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.starrocks.common.Config;
import com.starrocks.sql.common.ErrorType;
import com.starrocks.sql.common.StarRocksPlannerException;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.scalar.ArithmeticCommutativeRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ConsolidateLikesRule;
Expand All @@ -28,12 +29,14 @@
import com.starrocks.sql.optimizer.rewrite.scalar.NormalizePredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.PruneTediousPredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReduceCastRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ReplaceScalarOperatorRule;
import com.starrocks.sql.optimizer.rewrite.scalar.ScalarOperatorRewriteRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedCaseWhenRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedScanColumnRule;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ScalarOperatorRewriter {
Expand Down Expand Up @@ -157,4 +160,10 @@ public static ScalarOperator simplifyCaseWhen(ScalarOperator predicates) {
// simplify case-when
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_RULE);
}

/**
 * Rewrites {@code operator} with a {@link ReplaceScalarOperatorRule} built from {@code translateMap},
 * which substitutes any operator equivalent to a key of the map with the column reference mapped to it.
 *
 * @param operator the scalar expression to rewrite
 * @param translateMap expressions to look for, each mapped to its replacement column reference
 * @return the rewritten expression; {@code operator} itself when {@code translateMap} is empty
 */
public static ScalarOperator replaceScalarOperatorByColumnRef(ScalarOperator operator,
                                                              Map<ScalarOperator, ColumnRefOperator> translateMap) {
    // With no entries the rule can never match, so skip allocating the rule, the list and the rewriter.
    // This method runs on every expression translation, and the map is empty unless the query
    // involves generated columns.
    if (translateMap.isEmpty()) {
        return operator;
    }
    ReplaceScalarOperatorRule rule = new ReplaceScalarOperatorRule(translateMap);
    return new ScalarOperatorRewriter().rewrite(operator, Lists.newArrayList(rule));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.optimizer.rewrite.scalar;

import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriteContext;

import java.util.Map;

/**
 * Rewrite rule that swaps a scalar operator for a column reference whenever the operator is
 * equivalent to one of the keys of a caller-supplied map.
 */
public class ReplaceScalarOperatorRule extends BottomUpScalarOperatorRewriteRule {
    // Assigned once in the constructor and only read afterwards, hence final.
    private final Map<ScalarOperator, ColumnRefOperator> translateMap;

    /**
     * @param translateMap expressions to look for, each mapped to the column reference that replaces it;
     *                     the map is kept by reference, not copied
     */
    public ReplaceScalarOperatorRule(Map<ScalarOperator, ColumnRefOperator> translateMap) {
        this.translateMap = translateMap;
    }

    /**
     * Returns the mapped column reference of the first entry whose key is equivalent to
     * {@code scalarOperator}, or {@code scalarOperator} unchanged when no key matches.
     */
    @Override
    public ScalarOperator visit(ScalarOperator scalarOperator, ScalarOperatorRewriteContext context) {
        // Matching goes through ScalarOperator.isEquivalent rather than a map lookup, so every entry
        // has to be examined in turn.
        for (Map.Entry<ScalarOperator, ColumnRefOperator> entry : translateMap.entrySet()) {
            if (ScalarOperator.isEquivalent(entry.getKey(), scalarOperator)) {
                return entry.getValue();
            }
        }
        return scalarOperator;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class ExpressionMapping {
// record columnRefOp which is generated by const expr in project node
// if this columnRefOp is referenced by upper node, we should replace it with const expr in upper node
private Map<ColumnRefOperator, ScalarOperator> columnRefToConstOperators = new HashMap<>();
private Map<ScalarOperator, ColumnRefOperator> generatedColumnExprOpToColumnRef = new HashMap<>();

public ExpressionMapping(Scope scope, List<ColumnRefOperator> fieldMappings) {
this.scope = scope;
Expand Down Expand Up @@ -202,4 +203,12 @@ public Map<ColumnRefOperator, ScalarOperator> getColumnRefToConstOperators() {
/**
 * Copies every entry of the given map into this mapping's own {@code columnRefToConstOperators} map.
 *
 * @param columnRefToConstOperators entries to add
 */
public void addColumnRefToConstOperators(Map<ColumnRefOperator, ScalarOperator> columnRefToConstOperators) {
    Map<ColumnRefOperator, ScalarOperator> own = this.columnRefToConstOperators;
    own.putAll(columnRefToConstOperators);
}

/**
 * Returns this mapping's {@code generatedColumnExprOpToColumnRef} map.
 *
 * @return the stored map itself, not a copy
 */
public Map<ScalarOperator, ColumnRefOperator> getGeneratedColumnExprOpToColumnRef() {
    return this.generatedColumnExprOpToColumnRef;
}

/**
 * Copies every entry of the given map into this mapping's own
 * {@code generatedColumnExprOpToColumnRef} map.
 *
 * @param generatedColumnExprOpToColumnRef entries to add
 */
public void addGeneratedColumnExprOpToColumnRef(Map<ScalarOperator, ColumnRefOperator> generatedColumnExprOpToColumnRef) {
    Map<ScalarOperator, ColumnRefOperator> own = this.generatedColumnExprOpToColumnRef;
    own.putAll(generatedColumnExprOpToColumnRef);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ public Map<ColumnRefOperator, ScalarOperator> getColumnRefToConstOperators() {
return expressionMapping.getColumnRefToConstOperators();
}

/**
 * Returns the generated-column map of the expression mapping currently attached to this builder.
 *
 * @return whatever {@code expressionMapping.getGeneratedColumnExprOpToColumnRef()} returns
 */
public Map<ScalarOperator, ColumnRefOperator> getGeneratedColumnExprOpToColumnRef() {
    ExpressionMapping current = this.expressionMapping;
    return current.getGeneratedColumnExprOpToColumnRef();
}

/**
 * Replaces the expression mapping attached to this builder.
 *
 * @param expressionMapping the new mapping; the reference is stored as-is
 */
public void setExpressionMapping(ExpressionMapping expressionMapping) {
this.expressionMapping = expressionMapping;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@ public LogicalPlan plan(SelectRelation queryBlock, ExpressionMapping outer) {
builder.getColumnRefToConstOperators()));

Map<Expr, SlotRef> generatedExprToColumnRef = queryBlock.getGeneratedExprToColumnRef();
ExpressionMapping expressionMapping = builder.getExpressionMapping();
Map<ScalarOperator, ColumnRefOperator> generatedColumnExprOpToColumnRef = new HashMap<>();
for (Map.Entry<Expr, SlotRef> m : generatedExprToColumnRef.entrySet()) {
ScalarOperator scalarOperator = SqlToScalarOperatorTranslator.translate(m.getValue(),
ScalarOperator scalarOperator = SqlToScalarOperatorTranslator.translate(m.getKey(),
builder.getExpressionMapping(), columnRefFactory);
expressionMapping.put(m.getKey(), (ColumnRefOperator) scalarOperator);
ColumnRefOperator columnRefOp = (ColumnRefOperator) SqlToScalarOperatorTranslator.translate(m.getValue(),
builder.getExpressionMapping(), columnRefFactory);
generatedColumnExprOpToColumnRef.put(scalarOperator, columnRefOp);
}
builder.getExpressionMapping().addGeneratedColumnExprOpToColumnRef(generatedColumnExprOpToColumnRef);

builder = filter(builder, queryBlock.getPredicate());
builder = aggregate(builder, queryBlock.getGroupBy(), queryBlock.getAggregate(),
Expand Down Expand Up @@ -246,6 +249,7 @@ private OptExprBuilder projectForOrder(OptExprBuilder subOpt,

outputTranslations.addExpressionToColumns(subOpt.getExpressionMapping().getExpressionToColumns());
outputTranslations.addColumnRefToConstOperators(subOpt.getColumnRefToConstOperators());
outputTranslations.addGeneratedColumnExprOpToColumnRef(subOpt.getGeneratedColumnExprOpToColumnRef());

LogicalProjectOperator projectOperator = new LogicalProjectOperator(projections);
return new OptExprBuilder(projectOperator, Lists.newArrayList(subOpt), outputTranslations);
Expand Down Expand Up @@ -279,6 +283,7 @@ private OptExprBuilder project(OptExprBuilder subOpt, Iterable<Expr> expressions

outputTranslations.addExpressionToColumns(subOpt.getExpressionMapping().getExpressionToColumns());
outputTranslations.addColumnRefToConstOperators(subOpt.getColumnRefToConstOperators());
outputTranslations.addGeneratedColumnExprOpToColumnRef(subOpt.getGeneratedColumnExprOpToColumnRef());

LogicalProjectOperator projectOperator = new LogicalProjectOperator(projections, limit);
return new OptExprBuilder(projectOperator, Lists.newArrayList(subOpt), outputTranslations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ public static ScalarOperator translate(Expr expression, ExpressionMapping expres
ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter();
result = scalarRewriter.rewrite(result, ScalarOperatorRewriter.DEFAULT_REWRITE_RULES);

result = ScalarOperatorRewriter.replaceScalarOperatorByColumnRef(result,
expressionMapping.getGeneratedColumnExprOpToColumnRef());

requireNonNull(result, "translated expression is null");
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ public void test() throws Exception {

sql = " select tmc.v1 + 1 from tmc as v,tmc2 as tmc";
plan = getFragmentPlan(sql);
assertContains(plan, "<slot 8> : 5: v1 + 1");
assertContains(plan, "<slot 3> : 3: v3");

sql = " select tmc.v1 + 1 from tmc as v,tmc2 as tmc";
plan = getFragmentPlan(sql);
assertContains(plan, "<slot 8> : 5: v1 + 1");
assertContains(plan, "<slot 3> : 3: v3");

sql = " select * from view_1";
plan = getFragmentPlan(sql);
Expand Down
Loading
Loading