Skip to content

Commit

Permalink
Add rewriter for multiple similar APPROX_PERCENTILEs in same clause
Browse files Browse the repository at this point in the history
  • Loading branch information
jnmugerwa committed Mar 29, 2021
1 parent 2dbc1ef commit eea9344
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.coresql.rewriter;

import com.facebook.coresql.parser.AstNode;
import com.facebook.coresql.parser.FunctionCall;
import com.facebook.coresql.parser.SqlParserDefaultVisitor;
import com.google.common.collect.ImmutableListMultimap;

import java.util.Formatter;
import java.util.Map;
import java.util.Optional;

import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTARGUMENTLIST;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTIDENTIFIER;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTUNSIGNEDNUMERICLITERAL;
import static java.lang.Double.parseDouble;
import static java.util.Collections.binarySearch;
import static java.util.Comparator.naturalOrder;
import static java.util.Objects.requireNonNull;

public class ApproxPercentileRewriter
extends Rewriter
{
private static final int ZERO_INDEXING_OFFSET = 1;
private static final String REPLACEMENT_FORMAT = " APPROX_PERCENTILE(%s, ARRAY%s)[%d]";
private static final String REWRITE_NAME = "Multiple APPROX PERCENTILE with same first arg and literal second arg";
private final AstNode root;
private final ImmutableListMultimap<String, AstNode> firstArgToApproxPercentileNode; // A map of String to the APPROX_PERCENTILE nodes with that String as their first argument
private final ImmutableListMultimap<String, Double> firstArgToPercentiles; // A map of String to a list of the percentiles of the APPROX_PERCENTILE nodes that have that String as their first argument

public ApproxPercentileRewriter(AstNode root)
{
this.root = requireNonNull(root, "AST passed to rewriter was null");
this.firstArgToApproxPercentileNode = new ApproxPercentilePatternMatcher(root).matchPattern();
this.firstArgToPercentiles = getPercentilesFromFirstArgMap();
}

@Override
public Optional<RewriteResult> rewrite()
{
if (!approxPercentileRewritePatternIsPresent()) {
return Optional.empty();
}
String rewrittenSql = unparse(root, this);
return Optional.of(new RewriteResult(REWRITE_NAME, rewrittenSql));
}

@Override
public void visit(FunctionCall node, Void data)
{
if (canApplyApproxPercentileRewrite(node)) {
applyApproxPercentileRewrite(node);
}
else {
defaultVisit(node, data);
}
}

private boolean approxPercentileRewritePatternIsPresent()
{
return firstArgToApproxPercentileNode.keySet().stream()
.anyMatch(key -> firstArgToApproxPercentileNode.get(key).size() >= 2);
}

private boolean canApplyApproxPercentileRewrite(AstNode node)
{
return firstArgToApproxPercentileNode.containsValue(node) && firstArgToApproxPercentileNode.get(unparse(getNthArgument(node, 0).get())).size() >= 2;
}

/**
* Generates a rewritten version of the current subtree.
*
* @param node The function call node we're rewriting
*/
private void applyApproxPercentileRewrite(AstNode node)
{
// First, unparse up to the node's first child. This ensures we don't miss any special tokens
unparseUpto((AstNode) node.jjtGetChild(0));
// Then, add the rewritten version to the rewriter's result object (i.e. stringBuilder)
String firstArg = unparse(getNthArgument(node, 0).get());
Double secondArg = parseDouble(unparse(getNthArgument(node, 1).get()));
Formatter formatter = new Formatter(stringBuilder);
formatter.format(REPLACEMENT_FORMAT, firstArg, firstArgToPercentiles.get(firstArg), getIndexOfPercentile(firstArg, secondArg));
// Lastly, move to end of this node -- we've already added a rewritten version of it to the result, so we don't need to process it further
moveToEndOfNode(node);
}

private int getIndexOfPercentile(String firstArg, Double secondArg)
{
return binarySearch(firstArgToPercentiles.get(firstArg), secondArg) + ZERO_INDEXING_OFFSET;
}

private static Optional<AstNode> getNthArgument(AstNode node, int n)
{
Optional<AstNode> argList = Optional.ofNullable(node.GetFirstChildOfKind(JJTARGUMENTLIST));
if (!argList.isPresent() || argList.get().jjtGetNumChildren() < n) {
return Optional.empty();
}
AstNode nthArg = (AstNode) argList.get().jjtGetChild(n);
return Optional.of(nthArg);
}

private static boolean hasUnsignedLiteralSecondArg(AstNode node)
{
return getNthArgument(node, 1).filter(astNode -> astNode.getId() == JJTUNSIGNEDNUMERICLITERAL).isPresent();
}

private static boolean isApproxPercentileNode(AstNode node)
{
Optional<AstNode> identifier = Optional.ofNullable(node.GetFirstChildOfKind(JJTIDENTIFIER));
if (!identifier.isPresent()) {
return false;
}
Optional<String> image = Optional.ofNullable(identifier.get().GetImage());
return image.isPresent() && image.get().equalsIgnoreCase("APPROX_PERCENTILE");
}

private ImmutableListMultimap<String, Double> getPercentilesFromFirstArgMap()
{
ImmutableListMultimap.Builder<String, Double> firstArgToPercentiles = ImmutableListMultimap.builder();
for (Map.Entry<String, AstNode> entry : firstArgToApproxPercentileNode.entries()) {
String firstArg = entry.getKey();
AstNode approxPercentileNode = entry.getValue();
firstArgToPercentiles.put(firstArg, parseDouble(unparse(getNthArgument(approxPercentileNode, 1).get())));
}
// Sort each percentile list before returning. This will allow binary search downstream
return firstArgToPercentiles.orderValuesBy(naturalOrder()).build();
}

private static class ApproxPercentilePatternMatcher
extends SqlParserDefaultVisitor
{
private final AstNode root;
private final ImmutableListMultimap.Builder<String, AstNode> firstArgToApproxPercentileNode = ImmutableListMultimap.builder();

public ApproxPercentilePatternMatcher(AstNode root)
{
this.root = requireNonNull(root, "AST passed to pattern matcher was null");
}

public ImmutableListMultimap<String, AstNode> matchPattern()
{
root.jjtAccept(this, null);
return firstArgToApproxPercentileNode.build();
}

@Override
public void visit(FunctionCall node, Void data)
{
if (isApproxPercentileNode(node) && hasUnsignedLiteralSecondArg(node)) {
firstArgToApproxPercentileNode.put(unparse(getNthArgument(node, 0).get()), node);
}
defaultVisit(node, data);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.coresql.rewriter;

import com.facebook.coresql.parser.AstNode;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.Map;
import java.util.Optional;

import static com.facebook.coresql.parser.ParserHelper.parseStatement;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

public class TestApproxPercentileRewriter
{
private static final String[] STATEMENTS_THAT_DONT_NEED_REWRITE = new String[] {
// True Negative
"CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y LIMIT 10) ORDER BY z LIMIT 10;",
"SELECT dealer_id, sales OVER (PARTITION BY dealer_id ORDER BY sales);",
"INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date LIMIT 100);",
"SELECT (true or false) and false;",
"SELECT * FROM T ORDER BY y;",
"SELECT * FROM T ORDER BY y LIMIT 10;",
"use a.b;",
" SELECT 1;",
"SELECT a FROM T;",
"SELECT a FROM T WHERE p1 > p2;",
"SELECT a, b, c FROM T WHERE c1 < c2 and c3 < c4;",
"SELECT CASE a WHEN IN ( 1 ) THEN b ELSE c END AS x, b, c FROM T WHERE c1 < c2 and c3 < c4;",
"SELECT T.* FROM T JOIN W ON T.x = W.x;",
"SELECT NULL;",
"SELECT ARRAY[x] FROM T;",
"SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;",
"CREATE TABLE T AS SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;",
"INSERT INTO T SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;",
"SELECT ROW_NUMBER() OVER(PARTITION BY x) FROM T;",
"SELECT x, SUM(y) OVER (PARTITION BY y ORDER BY 1) AS min\n" +
"FROM (values ('b',10), ('a', 10)) AS T(x, y)\n;",
"SELECT\n" +
" CAST(MAP() AS map<bigint,array<boolean>>) AS \"bool_tensor_features\";",
"SELECT f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f())))))))))))))))))))))))))))));",
"SELECT abs, 2 as abs;",
// False Positive
"SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(y, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4) FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]) FROM (SELECT APPROX_PERCENTILE(x, 0.1) from T);"
};

private static final ImmutableMap<String, String> STATEMENT_TO_REWRITTEN_STATEMENT =
new ImmutableMap.Builder<String, String>()
.put("SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] AS percentile_30 FROM T;")
.put("SELECT APPROX_PERCENTILE(x, 0.1), APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1], APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] FROM T;")
.put("SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2;",
"SELECT APPROX_PERCENTILE(y, ARRAY[0.1, 0.2])[2], x + 1, APPROX_PERCENTILE(y, ARRAY[0.1, 0.2])[1] from T group by 2;")
.put("SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[2] AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;")
.put("SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4), APPROX_PERCENTILE(x, 0.5) FROM T;",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, ARRAY[0.4, 0.5])[1], APPROX_PERCENTILE(x, ARRAY[0.4, 0.5])[2] FROM T;")
.put("SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT APPROX_PERCENTILE(x, 0.2) FROM T);",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[1] FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[2] FROM T);")
.put("SELECT x FROM (SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T);",
"SELECT x FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] AS percentile_30 FROM T);")
.put("CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2));",
"CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT APPROX_PERCENTILE(y, ARRAY[0.1, 0.2])[2], x + 1, APPROX_PERCENTILE(y, ARRAY[0.1, 0.2])[1] from T group by 2));")
.put("SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT * FROM (SELECT approx_percentile(x, 0.1), approx_percentile(y, 0.1) from T group by 2));",
"SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.1])[1] FROM (SELECT * FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.1])[1], approx_percentile(y, 0.1) from T group by 2));")
.build();

private void assertStatementUnchanged(String originalStatement)
{
assertFalse(getRewriteResult(originalStatement).isPresent());
}

private void assertStatementRewritten(String originalStatement, String expectedStatement)
{
Optional<RewriteResult> result = getRewriteResult(originalStatement);
assertTrue(result.isPresent());
assertEquals(result.get().getRewrittenSql(), expectedStatement);
}

private Optional<RewriteResult> getRewriteResult(String originalStatement)
{
AstNode ast = parseStatement(originalStatement);
assertNotNull(ast);
return new ApproxPercentileRewriter(ast).rewrite();
}

@Test
public void rewriteTest()
{
for (Map.Entry<String, String> entry : STATEMENT_TO_REWRITTEN_STATEMENT.entrySet()) {
assertStatementRewritten(entry.getKey(), entry.getValue());
}

for (String sql : STATEMENTS_THAT_DONT_NEED_REWRITE) {
assertStatementUnchanged(sql);
}
}
}
Binary file added rewriter/~$pom.xml
Binary file not shown.

0 comments on commit eea9344

Please sign in to comment.