Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove identity casts from optimizer #329

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.Session;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
Expand All @@ -30,18 +31,18 @@
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.Collection;
Expand All @@ -58,6 +59,7 @@
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.IS_DISTINCT_FROM;
import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Type.OR;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableMap;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -94,6 +96,7 @@ private static class Rewriter
private final Session session;
private final Map<Symbol, Type> types;
private final PlanNodeIdAllocator idAllocator;
private Map<Symbol, Expression> expressionAssignments;

public Rewriter(Metadata metadata, SqlParser sqlParser, Session session, Map<Symbol, Type> types, PlanNodeIdAllocator idAllocator)
{
Expand All @@ -108,8 +111,10 @@ public Rewriter(Metadata metadata, SqlParser sqlParser, Session session, Map<Sym
public PlanNode visitProject(ProjectNode node, RewriteContext<Void> context)
{
PlanNode source = context.rewrite(node.getSource());
Map<Symbol, Expression> assignments = ImmutableMap.copyOf(Maps.transformValues(node.getAssignments(), this::simplifyExpression));
return new ProjectNode(node.getId(), source, assignments);
expressionAssignments = node.getAssignments();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't be using class field to pass this information between visit methods. Use context instead.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmm. I need to pass the assignments of the ProjectNode that I'm currently visiting as context to the rewriter. I can't pass additional context to simplifyExpression where the rewriters are called because it has a set function definition, which is why I put it in a class field. I'm not sure of a better way to get the context to the rewriter.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can have a wrapper class doing a composition of functions and assignments.

Map<Symbol, Expression> simplifiedAssignments = expressionAssignments.entrySet().stream()
.collect(toImmutableMap(entry -> entry.getKey(), entry -> simplifyExpression(entry.getValue())));
return new ProjectNode(node.getId(), source, simplifiedAssignments);
}

@Override
Expand Down Expand Up @@ -150,6 +155,11 @@ private Expression simplifyExpression(Expression expression)
if (expression instanceof SymbolReference) {
return expression;
}

if (expressionAssignments != null && types != null) {
RemoveIdentityCastContext removeIdentityCastContext = new RemoveIdentityCastContext(expressionAssignments, types);
expression = ExpressionTreeRewriter.rewriteWith(new RemoveIdentityCastsRewriter(), expression, removeIdentityCastContext);
}
expression = ExpressionTreeRewriter.rewriteWith(new PushDownNegationsExpressionRewriter(), expression);
expression = ExpressionTreeRewriter.rewriteWith(new ExtractCommonPredicatesExpressionRewriter(), expression, NodeContext.ROOT_NODE);
IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression, emptyList() /* parameters already replaced */);
Expand Down Expand Up @@ -271,4 +281,70 @@ private static <T> List<T> removeAll(Collection<T> collection, Collection<T> ele
.collect(toImmutableList());
}
}

private static class RemoveIdentityCastContext
{
private final Map<Symbol, Expression> expressionAssignments;
private final Map<Symbol, Type> typeAssignments;

public RemoveIdentityCastContext(Map<Symbol, Expression> expressionAssignments,
Map<Symbol, Type> typeAssignments)
{
requireNonNull(expressionAssignments);
requireNonNull(typeAssignments);

this.expressionAssignments = expressionAssignments;
this.typeAssignments = typeAssignments;
}

public Map<Symbol, Expression> getExpressionAssignments()
{
return expressionAssignments;
}

public Map<Symbol, Type> getTypeAssignments()
{
return typeAssignments;
}
}

private static class RemoveIdentityCastsRewriter
extends ExpressionRewriter<RemoveIdentityCastContext>
{
@Override
public Expression rewriteExpression(Expression node, RemoveIdentityCastContext context,
ExpressionTreeRewriter<RemoveIdentityCastContext> treeRewriter)
{
Map<Symbol, Type> typeAssignments = context.getTypeAssignments();
for (Map.Entry<Symbol, Expression> expressionAssignment : context.getExpressionAssignments().entrySet()) {
Symbol assignmentSymbol = expressionAssignment.getKey();
Expression expression = expressionAssignment.getValue();
if (expression == node) {
if (!(expression instanceof Cast)) {
return expression;
}

Expression expressionToCast = ((Cast) expression).getExpression();
TypeSignature typeOfExpressionToCastTo = typeAssignments.get(assignmentSymbol).getTypeSignature();
TypeSignature typeOfExpressionToCast;
if (expressionToCast instanceof SymbolReference) {
Symbol expressionSymbol = new Symbol(((SymbolReference) expressionToCast).getName());
typeOfExpressionToCast = typeAssignments.get(expressionSymbol).getTypeSignature();
}
else if (expressionToCast instanceof GenericLiteral) {
typeOfExpressionToCast = TypeSignature.parseTypeSignature(((GenericLiteral) expressionToCast).getType());
}
else {
return expression;
}

if (typeOfExpressionToCast.equals(typeOfExpressionToCastTo)) {
return expressionToCast;
}
return expression;
}
}
return node;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import org.intellij.lang.annotations.Language;

public interface PlanTester
{
Copy link

@kokosing kokosing Sep 1, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO interface here is an over engineering. I would prefer abstract class with below methods implemented.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this iface is needed at all?

public void assertPlanDoesNotMatch(@Language("SQL") String sql, PlanMatchPattern pattern);

public void assertPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern);

public Plan plan(@Language("SQL") String sql);
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import com.facebook.presto.tpch.TpchConnectorFactory;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

import java.util.List;
Expand Down Expand Up @@ -54,10 +56,12 @@
import static org.testng.Assert.fail;

public class TestLogicalPlanner
implements PlanTester
{
private final LocalQueryRunner queryRunner;
private LocalQueryRunner queryRunner;

public TestLogicalPlanner()
@BeforeTest
public void setUp()
{
this.queryRunner = new LocalQueryRunner(testSessionBuilder()
.setCatalog("local")
Expand All @@ -72,7 +76,7 @@ public TestLogicalPlanner()
@Test
public void testJoin()
{
assertPlan("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey = o.orderkey",
assertPlanMatches("SELECT o.orderkey FROM orders o, lineitem l WHERE l.orderkey = o.orderkey",
anyTree(
join(INNER, ImmutableList.of(aliasPair("O", "L")),
any(
Expand All @@ -84,7 +88,7 @@ public void testJoin()
@Test
public void testUncorrelatedSubqueries()
{
assertPlan("SELECT * FROM orders WHERE orderkey = (SELECT orderkey FROM lineitem ORDER BY orderkey LIMIT 1)",
assertPlanMatches("SELECT * FROM orders WHERE orderkey = (SELECT orderkey FROM lineitem ORDER BY orderkey LIMIT 1)",
anyTree(
join(INNER, ImmutableList.of(aliasPair("X", "Y")),
project(
Expand All @@ -94,7 +98,7 @@ public void testUncorrelatedSubqueries()
anyTree(
tableScan("lineitem").withSymbol("orderkey", "Y")))))));

assertPlan("SELECT * FROM orders WHERE orderkey IN (SELECT orderkey FROM lineitem WHERE linenumber % 4 = 0)",
assertPlanMatches("SELECT * FROM orders WHERE orderkey IN (SELECT orderkey FROM lineitem WHERE linenumber % 4 = 0)",
anyTree(
filter("S",
project(
Expand All @@ -104,7 +108,7 @@ public void testUncorrelatedSubqueries()
anyTree(
tableScan("lineitem").withSymbol("orderkey", "Y")))))));

assertPlan("SELECT * FROM orders WHERE orderkey NOT IN (SELECT orderkey FROM lineitem WHERE linenumber < 0)",
assertPlanMatches("SELECT * FROM orders WHERE orderkey NOT IN (SELECT orderkey FROM lineitem WHERE linenumber < 0)",
anyTree(
filter("NOT S",
project(
Expand All @@ -122,7 +126,7 @@ public void testPushDownJoinConditionConjunctsToInnerSideBasedOnInheritedPredica
.put("name", singleValue(createVarcharType(25), utf8Slice("blah")))
.build();

assertPlan(
assertPlanMatches(
"SELECT nationkey FROM nation LEFT OUTER JOIN region " +
"ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah'",
anyTree(
Expand Down Expand Up @@ -187,7 +191,7 @@ private void assertPlanContainsNoApplyOrJoin(String sql)
@Test
public void testCorrelatedSubqueries()
{
assertPlan(
assertPlanMatches(
"SELECT orderkey FROM orders WHERE 3 = (SELECT orderkey)",
LogicalPlanner.Stage.OPTIMIZED,
anyTree(
Expand All @@ -200,7 +204,7 @@ public void testCorrelatedSubqueries()
))))));

// double nesting
assertPlan(
assertPlanMatches(
"SELECT orderkey FROM orders o " +
"WHERE 3 IN (SELECT o.custkey FROM lineitem l WHERE (SELECT l.orderkey = o.orderkey))",
LogicalPlanner.Stage.OPTIMIZED,
Expand All @@ -218,21 +222,26 @@ public void testCorrelatedSubqueries()
))))))));
}

private void assertPlan(String sql, PlanMatchPattern pattern)
public void assertPlanDoesNotMatch(@Language("SQL") String sql, PlanMatchPattern pattern)
{
assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern);
throw new UnsupportedOperationException("assertPlanDoesNotMatch() is not supported");
}

private void assertPlan(String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern)
public void assertPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern)
{
assertPlanMatches(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern);
}

private void assertPlanMatches(@Language("SQL") String sql, LogicalPlanner.Stage stage, PlanMatchPattern pattern)
{
Plan actualPlan = plan(sql, stage);
queryRunner.inTransaction(transactionSession -> {
PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), actualPlan, pattern);
PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern);
return null;
});
}

private Plan plan(String sql)
public Plan plan(String sql)
{
return plan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
Expand Down Expand Up @@ -115,6 +116,16 @@ else if (expression instanceof GenericLiteral) {
}
}

@Override
protected Boolean visitCast(Cast actual, Expression expectedExpession)
{
if (expectedExpession instanceof Cast) {
Cast expected = (Cast) expectedExpession;
return process(actual.getExpression(), expected.getExpression());
}
return false;
}

@Override
protected Boolean visitStringLiteral(StringLiteral actual, Expression expectedExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,37 @@
import static com.facebook.presto.sql.planner.PlanPrinter.textLogicalPlan;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public final class PlanAssert
{
private PlanAssert() {}

public static void assertPlan(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern)
public static void assertPlanMatches(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern)
{
assertPlan(session, metadata, actual, pattern, true);
}

public static void assertPlanDoesNotMatch(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern)
{
assertPlan(session, metadata, actual, pattern, false);
}

private static void assertPlan(Session session, Metadata metadata, Plan actual, PlanMatchPattern pattern, boolean expectedMatch)
{
requireNonNull(actual, "root is null");

boolean matches = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata), new PlanMatchingContext(pattern));
if (!matches) {
boolean actualMatch = actual.getRoot().accept(new PlanMatchingVisitor(session, metadata), new PlanMatchingContext(pattern));
if (expectedMatch != actualMatch) {
String logicalPlan = textLogicalPlan(actual.getRoot(), actual.getTypes(), metadata, session);
assertTrue(matches, format("Plan does not match:\n %s\n, to pattern:\n%s", logicalPlan, pattern));
String errorMessage;
if (expectedMatch) {
errorMessage = format("Plan does not match:\n%s\nto pattern:\n%s", logicalPlan, pattern);
}
else {
errorMessage = format("Plan matches:\n%s\nto pattern:\n%s", logicalPlan, pattern);
}
fail(errorMessage);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
Expand All @@ -38,6 +39,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -147,6 +149,11 @@ List<PlanMatchingState> matches(PlanNode node, Session session, Metadata metadat
return states.build();
}

public PlanMatchPattern withAssignments(Map<Symbol, Expression> assignments)
{
return with(new ProjectNodeMatcher(assignments));
}

public PlanMatchPattern withSymbol(String pattern, String alias)
{
return with(new SymbolMatcher(pattern, alias));
Expand Down
Loading