Skip to content

Commit

Permalink
Add rewriter for multiple similar APPROX_PERCENTILEs in same clause
Browse files Browse the repository at this point in the history
  • Loading branch information
jnmugerwa committed Apr 1, 2021
1 parent 2dbc1ef commit e8204c4
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.coresql.rewriter;

import com.facebook.coresql.parser.AstNode;
import com.facebook.coresql.parser.FunctionCall;
import com.facebook.coresql.parser.SqlParserDefaultVisitor;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;

import java.util.Formatter;
import java.util.Optional;

import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTARGUMENTLIST;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTIDENTIFIER;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTUNSIGNEDNUMERICLITERAL;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public class ApproxPercentileRewriter
extends Rewriter
{
private static final int ZERO_INDEXING_OFFSET = 1;
private static final String REPLACEMENT_FORMAT = "(%s, ARRAY[%s])[%d]";
private static final String REWRITE_NAME = "Multiple APPROX PERCENTILE with same first arg and literal second arg";
private static final String APPROX_PERCENTILE_IDENTIFIER = "APPROX_PERCENTILE";
private final AstNode root;
private final ListMultimap<String, AstNode> firstArgToApproxPercentileNode; // A map of String to the APPROX_PERCENTILE nodes with that String as their first argument
private final ListMultimap<String, AstNode> firstArgToPercentiles; // A map of first argument to the node that contains the percentiles corresponding to that first argument

public ApproxPercentileRewriter(AstNode root)
{
this.root = requireNonNull(root, "AST passed to rewriter was null");
ApproxPercentilePatternMatcherResult patternMatcherResult = new ApproxPercentilePatternMatcher(root).matchPattern();
this.firstArgToApproxPercentileNode = patternMatcherResult.getFirstArgToApproxPercentileNode();
this.firstArgToPercentiles = patternMatcherResult.getFirstArgToPercentiles();
}

@Override
public Optional<RewriteResult> rewrite()
{
if (!approxPercentileRewritePatternIsPresent()) {
return Optional.empty();
}
String rewrittenSql = unparse(root, this);
return Optional.of(new RewriteResult(REWRITE_NAME, rewrittenSql));
}

@Override
public void visit(FunctionCall node, Void data)
{
if (needsApproxPercentileRewrite(node)) {
applyApproxPercentileRewrite(node);
}
else {
defaultVisit(node, data);
}
}

private boolean approxPercentileRewritePatternIsPresent()
{
return firstArgToApproxPercentileNode.keySet().stream()
.anyMatch(key -> firstArgToApproxPercentileNode.get(key).size() >= 2);
}

private boolean needsApproxPercentileRewrite(FunctionCall node)
{
return firstArgToApproxPercentileNode.containsValue(node) && firstArgToApproxPercentileNode.get(unparse(getNthArgument(node, 0).get()).trim()).size() >= 2;
}

private String getPercentilesAsString(String firstArg)
{
String percentilesAsString = firstArgToPercentiles.get(firstArg).stream()
.map(node -> unparse(node).trim())
.collect(toList())
.toString();
return percentilesAsString.substring(1, percentilesAsString.length() - 1);
}

/**
* Generates a rewritten version of the current subtree.
*
* @param node The function call node we're rewriting
*/
private void applyApproxPercentileRewrite(AstNode node)
{
// First, unparse up to the node's last child. This ensures we don't miss any special tokens
unparseUpto(node.LastChild());
// Then, add the rewritten version to the rewriter's result object (i.e. stringBuilder)
String firstArg = unparse(getNthArgument(node, 0).get()).trim();
AstNode secondArg = getNthArgument(node, 1).get();
Formatter formatter = new Formatter(stringBuilder);
formatter.format(REPLACEMENT_FORMAT, firstArg, getPercentilesAsString(firstArg), getIndexOfPercentile(firstArg, secondArg));
// Lastly, move to end of this node -- we've already added a rewritten version of it to the result, so we don't need to process it further
moveToEndOfNode(node);
}

private int getIndexOfPercentile(String firstArg, AstNode secondArg)
{
return firstArgToPercentiles.get(firstArg).indexOf(secondArg) + ZERO_INDEXING_OFFSET;
}

private static Optional<AstNode> getNthArgument(AstNode node, int n)
{
Optional<AstNode> argList = Optional.ofNullable(node.GetFirstChildOfKind(JJTARGUMENTLIST));
if (!argList.isPresent() || argList.get().jjtGetNumChildren() < n) {
return Optional.empty();
}
AstNode nthArg = (AstNode) argList.get().jjtGetChild(n);
return Optional.of(nthArg);
}

private static boolean hasUnsignedLiteralSecondArg(AstNode node)
{
return getNthArgument(node, 1).filter(astNode -> astNode.getId() == JJTUNSIGNEDNUMERICLITERAL).isPresent();
}

private static boolean isApproxPercentileNode(AstNode node)
{
Optional<AstNode> identifier = Optional.ofNullable(node.GetFirstChildOfKind(JJTIDENTIFIER));
if (!identifier.isPresent()) {
return false;
}
Optional<String> image = Optional.ofNullable(identifier.get().GetImage());
return image.isPresent() && image.get().equalsIgnoreCase(APPROX_PERCENTILE_IDENTIFIER);
}

private static class ApproxPercentilePatternMatcher
extends SqlParserDefaultVisitor
{
private final AstNode root;
private final ImmutableListMultimap.Builder<String, AstNode> firstArgToApproxPercentileNode = ImmutableListMultimap.builder();
private final ImmutableListMultimap.Builder<String, AstNode> firstArgToPercentiles = ImmutableListMultimap.builder();

public ApproxPercentilePatternMatcher(AstNode root)
{
this.root = requireNonNull(root, "AST passed to pattern matcher was null");
}

public ApproxPercentilePatternMatcherResult matchPattern()
{
root.jjtAccept(this, null);
return new ApproxPercentilePatternMatcherResult(firstArgToApproxPercentileNode.build(), firstArgToPercentiles.build());
}

@Override
public void visit(FunctionCall node, Void data)
{
if (isApproxPercentileNode(node) && hasUnsignedLiteralSecondArg(node)) {
String firstArg = unparse(getNthArgument(node, 0).get()).trim();
AstNode secondArg = getNthArgument(node, 1).get();
firstArgToApproxPercentileNode.put(firstArg, node);
firstArgToPercentiles.put(firstArg, secondArg);
}
defaultVisit(node, data);
}
}

private static class ApproxPercentilePatternMatcherResult
{
private final ListMultimap<String, AstNode> firstArgToApproxPercentileNode;
private final ListMultimap<String, AstNode> firstArgToPercentiles;

public ApproxPercentilePatternMatcherResult(ListMultimap<String, AstNode> firstArgToApproxPercentileNode,
ListMultimap<String, AstNode> firstArgToPercentiles)
{
this.firstArgToApproxPercentileNode = requireNonNull(firstArgToApproxPercentileNode);
this.firstArgToPercentiles = requireNonNull(firstArgToPercentiles);
}

public ListMultimap<String, AstNode> getFirstArgToApproxPercentileNode()
{
return firstArgToApproxPercentileNode;
}

public ListMultimap<String, AstNode> getFirstArgToPercentiles()
{
return firstArgToPercentiles;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.coresql.rewriter;

import com.facebook.coresql.parser.AstNode;
import org.testng.annotations.Test;

import java.util.Optional;

import static com.facebook.coresql.parser.ParserHelper.parseStatement;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

public class TestApproxPercentileRewriter
{
private static final String[] STATEMENT_THAT_DOESNT_NEED_REWRITE = new String[] {
// True Negative
"CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y LIMIT 10) ORDER BY z LIMIT 10;",
"SELECT dealer_id, sales OVER (PARTITION BY dealer_id ORDER BY sales);",
"INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date LIMIT 100);",
"SELECT (true or false) and false;",
"SELECT * FROM T ORDER BY y;",
"SELECT * FROM T ORDER BY y LIMIT 10;",
"use a.b;",
" SELECT 1;",
"SELECT a FROM T;",
"SELECT a FROM T WHERE p1 > p2;",
"SELECT a, b, c FROM T WHERE c1 < c2 and c3 < c4;",
"SELECT CASE a WHEN IN ( 1 ) THEN b ELSE c END AS x, b, c FROM T WHERE c1 < c2 and c3 < c4;",
"SELECT T.* FROM T JOIN W ON T.x = W.x;",
"SELECT NULL;",
"SELECT ARRAY[x] FROM T;",
"SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;",
"CREATE TABLE T AS SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;",
"INSERT INTO T SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;",
"SELECT ROW_NUMBER() OVER(PARTITION BY x) FROM T;",
"SELECT x, SUM(y) OVER (PARTITION BY y ORDER BY 1) AS min\n" +
"FROM (values ('b',10), ('a', 10)) AS T(x, y)\n;",
"SELECT\n" +
" CAST(MAP() AS map<bigint,array<boolean>>) AS \"bool_tensor_features\";",
"SELECT f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f())))))))))))))))))))))))))))));",
"SELECT abs, 2 as abs;",
// False Positive
"SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(y, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4) FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]) FROM (SELECT APPROX_PERCENTILE(x, 0.1) from T);"
};

private static final String[] STATEMENT_BEFORE_REWRITE = new String[] {
"SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, 0.1), APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) FROM T;",
"SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2;",
"SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4), APPROX_PERCENTILE(x, 0.5) FROM T;",
"SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT APPROX_PERCENTILE(x, 0.2) FROM T);",
"SELECT x FROM (SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T);",
"CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2));",
"SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT * FROM (SELECT approx_percentile(x, 0.1), approx_percentile(y, 0.1) from T group by 2));"
};

private static final String[] STATEMENT_AFTER_REWRITE = new String[] {
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1], APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] FROM T;",
"SELECT approx_percentile(y, ARRAY[0.2, 0.1])[1], x + 1, approx_percentile(y, ARRAY[0.2, 0.1])[2] from T group by 2;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[2] AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, ARRAY[0.4, 0.5])[1], APPROX_PERCENTILE(x, ARRAY[0.4, 0.5])[2] FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[1] FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[2] FROM T);",
"SELECT x FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] AS percentile_30 FROM T);",
"CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT approx_percentile(y, ARRAY[0.2, 0.1])[1], x + 1, approx_percentile(y, ARRAY[0.2, 0.1])[2] from T group by 2));",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.1])[1] FROM (SELECT * FROM (SELECT approx_percentile(x, ARRAY[0.1, 0.1])[2], approx_percentile(y, 0.1) from T group by 2));"
};

private void assertStatementUnchanged(String originalStatement)
{
assertFalse(getRewriteResult(originalStatement).isPresent());
}

private void assertStatementRewritten(String originalStatement, String expectedStatement)
{
Optional<RewriteResult> result = getRewriteResult(originalStatement);
assertTrue(result.isPresent());
assertEquals(result.get().getRewrittenSql(), expectedStatement);
}

private Optional<RewriteResult> getRewriteResult(String originalStatement)
{
AstNode ast = parseStatement(originalStatement);
assertNotNull(ast);
return new ApproxPercentileRewriter(ast).rewrite();
}

@Test
public void rewriteTest()
{
for (int i = 0; i < STATEMENT_BEFORE_REWRITE.length; i++) {
assertStatementRewritten(STATEMENT_BEFORE_REWRITE[i], STATEMENT_AFTER_REWRITE[i]);
}

for (String sql : STATEMENT_THAT_DOESNT_NEED_REWRITE) {
assertStatementUnchanged(sql);
}
}
}
Binary file added rewriter/~$pom.xml
Binary file not shown.

0 comments on commit e8204c4

Please sign in to comment.