Skip to content

Commit

Permalink
work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
zachyee committed Sep 2, 2016
1 parent 6e8721e commit ed8b85d
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.sql.planner.assertions;

import com.facebook.presto.sql.tree.AstVisitor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
Expand Down Expand Up @@ -115,6 +116,16 @@ else if (expression instanceof GenericLiteral) {
}
}

@Override
protected Boolean visitCast(Cast actual, Expression expectedExpession)
{
if (expectedExpession instanceof Cast) {
Cast expected = (Cast) expectedExpession;
return process(actual.getExpression(), expected.getExpression());
}
return false;
}

@Override
protected Boolean visitStringLiteral(StringLiteral actual, Expression expectedExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.predicate.Domain;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
Expand All @@ -38,6 +39,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -147,9 +149,9 @@ List<PlanMatchingState> matches(PlanNode node, Session session, Metadata metadat
return states.build();
}

public PlanMatchPattern withAssignment(String pattern)
public PlanMatchPattern withAssignments(Map<Symbol, Expression> assignments)
{
return with(new ProjectNodeMatcher(pattern));
return with(new ProjectNodeMatcher(assignments));
}

public PlanMatchPattern withSymbol(String pattern, String alias)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,20 @@
final class ProjectNodeMatcher
implements Matcher
{
private final Pattern pattern;
private final Map<Symbol, Expression> assignments;

ProjectNodeMatcher(String pattern)
ProjectNodeMatcher(Map<Symbol, Expression> assignments)
{
this.pattern = Pattern.compile(pattern);
this.assignments = assignments;
}

@Override
public boolean matches(PlanNode node, Session session, Metadata metadata, ExpressionAliases expressionAliases)
{
if (node instanceof ProjectNode) {
ProjectNode projectNode = (ProjectNode) node;
for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
Expression expression = assignment.getValue();
if (pattern.matcher(expression.getClass().getSimpleName()).find()) {
return true;
}
if (projectNode.getAssignments().equals(assignments)) {
return true;
}
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
*/
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.Plan;
Expand All @@ -24,18 +28,27 @@
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.tpch.TpchConnectorFactory;
import com.facebook.presto.type.TypeRegistry;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

import javax.inject.Provider;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -59,6 +72,9 @@ public class TestSimplifyExpressions
private static final SqlParser SQL_PARSER = new SqlParser();
private static final SimplifyExpressions SIMPLIFIER = new SimplifyExpressions(createTestMetadataManager(), SQL_PARSER);
private LocalQueryRunner queryRunner;
private Map<Symbol, Expression> inputAssignments;
private Map<Symbol, Expression> expectedAssignments;
private Map<Symbol, Type> typeAssignments;

@BeforeTest
public void setUp()
Expand All @@ -71,6 +87,10 @@ public void setUp()
queryRunner.createCatalog(queryRunner.getDefaultSession().getCatalog().get(),
new TpchConnectorFactory(queryRunner.getNodeManager(), 1),
ImmutableMap.<String, String>of());

inputAssignments = new HashMap<Symbol, Expression>();
expectedAssignments = new HashMap<Symbol, Expression>();
typeAssignments = new HashMap<Symbol, Type>();
}

@Test
Expand Down Expand Up @@ -116,9 +136,27 @@ public void testExtractsCommonPredicate()
}

@Test
public void testRemoveIdentityCasts()
public void testRemoveIdentityCastBigintLiteral()
{
inputAssignments.put(new Symbol("expr"), new Cast(new GenericLiteral("BIGINT", "5"), "BIGINT"));
expectedAssignments.put(new Symbol("expr"), new GenericLiteral("BIGINT", "5"));
typeAssignments.put(new Symbol("expr"), BigintType.BIGINT);
ProjectNode simplifiedProjectNode = simplifyProjectNode(inputAssignments, typeAssignments);
Map<Symbol, Expression> actualAssignments = simplifiedProjectNode.getAssignments();
assert (actualAssignments.equals(expectedAssignments));
}

@Test
public void testRemoveIdentityCastMultiplyBigintLiteral()
{
assertCastNotInPlan("SELECT CAST(BIGINT '5' as BIGINT)");
inputAssignments.put(new Symbol("expr"),
new Cast(new ArithmeticBinaryExpression(BigintType.BIGINT,
new GenericLiteral("BIGINT", "3"),
),
"BIGINT"))
}

/*
assertCastNotInPlan("SELECT 3 * CAST(BIGINT '5' as BIGINT)");
assertCastNotInPlan("SELECT CAST(nationkey AS BIGINT) FROM nation");
assertCastNotInPlan("SELECT 3 * CAST(nationkey AS BIGINT) FROM nation");
Expand All @@ -134,8 +172,10 @@ public void testRemoveIdentityCasts()
assertCastInPlan("SELECT CAST(nationkey AS SMALLINT) FROM nation");
assertCastInPlan("SELECT CAST(name AS VARCHAR(30)) FROM nation");
assertCastInPlan("SELECT CAST(name AS VARCHAR) FROM nation");
}
assertCastInPlan("SELECT CAST(name AS VARCHAR) FROM nation");*/


private void assertRemovesIdentityCast()

public void assertPlanDoesNotMatch(@Language("SQL") String sql, PlanMatchPattern pattern)
{
Expand All @@ -160,32 +200,43 @@ public Plan plan(@Language("SQL") String sql)
return queryRunner.inTransaction(transactionSession -> queryRunner.createPlan(transactionSession, sql));
}

private void assertCastNotInPlan(@Language("SQL") String sql)
private void assertUnitPlanMatches(@Language("SQL") String sql, PlanMatchPattern pattern)
{
PlanMatchPattern pattern = anyTree(
project(anyTree()).withAssignment("Cast")
);
assertPlanDoesNotMatch(sql, pattern);
Plan actualPlan = unitPlan(sql);
queryRunner.inTransaction(transactionSession -> {
PlanAssert.assertPlanMatches(transactionSession, queryRunner.getMetadata(), actualPlan, pattern);
return null;
});
}

private void assertCastInPlan(@Language("SQL") String sql)
private Plan unitPlan(@Language("SQL") String sql)
{
PlanMatchPattern pattern = anyTree(
project(anyTree()).withAssignment("Cast")
);
assertPlanMatches(sql, pattern);
FeaturesConfig featuresConfig = new FeaturesConfig()
.setExperimentalSyntaxEnabled(true)
.setDistributedIndexJoinsEnabled(false)
.setOptimizeHashGeneration(true);
Metadata metadata = new MetadataManager(featuresConfig,
new TypeRegistry(),
)
Provider<List<PlanOptimizer>> optimizerProvider = () -> ImmutableList.of(
new UnaliasSymbolReferences(),
new PruneIdentityProjections(),
new MergeIdenticalWindows(),
new PruneUnreferencedOutputs(),
new SimplifyExpressions());
return queryRunner.inTransaction(transactionSession -> queryRunner.createPlan(transactionSession, sql, featuresConfig, optimizerProvider));
}

private static void assertSimplifies(String expression, String expected)
{
Expression actualExpression = rewriteQualifiedNamesToSymbolReferences(SQL_PARSER.createExpression(expression));
Expression expectedExpression = rewriteQualifiedNamesToSymbolReferences(SQL_PARSER.createExpression(expected));
assertEquals(
normalize(simplifyExpressions(actualExpression)),
normalize(simplifyFilterNode(actualExpression)),
normalize(expectedExpression));
}

private static Expression simplifyExpressions(Expression expression)
private static Expression simplifyFilterNode(Expression expression)
{
PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator();
FilterNode filterNode = new FilterNode(
Expand All @@ -200,6 +251,19 @@ private static Expression simplifyExpressions(Expression expression)
return simplifiedNode.getPredicate();
}

private static ProjectNode simplifyProjectNode(Map<Symbol, Expression> assignments, Map<Symbol, Type> typeAssignments)
{
PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator();
ProjectNode projectNode = new ProjectNode(planNodeIdAllocator.getNextId(),
new ValuesNode(planNodeIdAllocator.getNextId(), emptyList(), emptyList()),
assignments);
return (ProjectNode) SIMPLIFIER.optimize(projectNode,
TEST_SESSION,
typeAssignments,
new SymbolAllocator(),
planNodeIdAllocator);
}

private static Map<Symbol, Type> booleanSymbolTypeMapFor(Expression expression)
{
return DependencyExtractor.extractUnique(expression).stream()
Expand Down

0 comments on commit ed8b85d

Please sign in to comment.