diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java new file mode 100644 index 0000000..050e10e --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java @@ -0,0 +1,182 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import com.facebook.coresql.parser.FunctionCall; +import com.facebook.coresql.parser.SqlParserDefaultVisitor; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; + +import java.util.Formatter; +import java.util.Optional; + +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTARGUMENTLIST; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTIDENTIFIER; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTUNSIGNEDNUMERICLITERAL; +import static java.util.Objects.requireNonNull; + +public class ApproxPercentileRewriter + extends Rewriter +{ + private static final int ZERO_INDEXING_OFFSET = 1; + private static final String REPLACEMENT_FORMAT = " APPROX_PERCENTILE(%s, ARRAY%s)[%d]"; + private static final String REWRITE_NAME = "Multiple APPROX PERCENTILE with same first arg and literal second arg"; + private final AstNode root; + private final ListMultimap firstArgToApproxPercentileNode; // A map of String to the APPROX_PERCENTILE nodes with that String as their first argument + private final ListMultimap firstArgToPercentiles; // A map of String to a list of the percentiles of the APPROX_PERCENTILE nodes that have that String as their first argument + + public ApproxPercentileRewriter(AstNode root) + { + this.root = requireNonNull(root, "AST passed to rewriter was null"); + ApproxPercentilePatternMatcherResult patternMatcherResult = new ApproxPercentilePatternMatcher(root).matchPattern(); + this.firstArgToApproxPercentileNode = patternMatcherResult.getFirstArgToApproxPercentileNode(); + this.firstArgToPercentiles = patternMatcherResult.getFirstArgToPercentiles(); + } + + @Override + public Optional rewrite() + { + if (!approxPercentileRewritePatternIsPresent()) { + return Optional.empty(); + } + String rewrittenSql = unparse(root, this); + return Optional.of(new RewriteResult(REWRITE_NAME, rewrittenSql)); + } + + @Override + public void visit(FunctionCall node, Void data) + { + if (canApplyApproxPercentileRewrite(node)) { + applyApproxPercentileRewrite(node); + } + else { + defaultVisit(node, data); + } + } + + private boolean approxPercentileRewritePatternIsPresent() + { + return firstArgToApproxPercentileNode.keySet().stream() + .anyMatch(key -> firstArgToApproxPercentileNode.get(key).size() >= 2); + } + + private boolean canApplyApproxPercentileRewrite(AstNode node) + { + return firstArgToApproxPercentileNode.containsValue(node) && firstArgToApproxPercentileNode.get(unparse(getNthArgument(node, 0).get()).trim()).size() >= 2; + } + + /** + * Generates a rewritten version of the current subtree. + * + * @param node The function call node we're rewriting + */ + private void applyApproxPercentileRewrite(AstNode node) + { + // First, unparse up to the node's first child. This ensures we don't miss any special tokens + unparseUpto((AstNode) node.jjtGetChild(0)); + // Then, add the rewritten version to the rewriter's result object (i.e. stringBuilder) + String firstArg = unparse(getNthArgument(node, 0).get()).trim(); + String secondArg = unparse(getNthArgument(node, 1).get()).trim(); + Formatter formatter = new Formatter(stringBuilder); + formatter.format(REPLACEMENT_FORMAT, firstArg, firstArgToPercentiles.get(firstArg), getIndexOfPercentile(firstArg, secondArg)); + // Lastly, move to end of this node -- we've already added a rewritten version of it to the result, so we don't need to process it further + moveToEndOfNode(node); + } + + private int getIndexOfPercentile(String firstArg, String secondArg) + { + return firstArgToPercentiles.get(firstArg).indexOf(secondArg) + ZERO_INDEXING_OFFSET; + } + + private static Optional getNthArgument(AstNode node, int n) + { + Optional argList = Optional.ofNullable(node.GetFirstChildOfKind(JJTARGUMENTLIST)); + if (!argList.isPresent() || argList.get().jjtGetNumChildren() < n) { + return Optional.empty(); + } + AstNode nthArg = (AstNode) argList.get().jjtGetChild(n); + return Optional.of(nthArg); + } + + private static boolean hasUnsignedLiteralSecondArg(AstNode node) + { + return getNthArgument(node, 1).filter(astNode -> astNode.getId() == JJTUNSIGNEDNUMERICLITERAL).isPresent(); + } + + private static boolean isApproxPercentileNode(AstNode node) + { + Optional identifier = Optional.ofNullable(node.GetFirstChildOfKind(JJTIDENTIFIER)); + if (!identifier.isPresent()) { + return false; + } + Optional image = Optional.ofNullable(identifier.get().GetImage()); + return image.isPresent() && image.get().equalsIgnoreCase("APPROX_PERCENTILE"); + } + + private static class ApproxPercentilePatternMatcher + extends SqlParserDefaultVisitor + { + private final AstNode root; + private final ImmutableListMultimap.Builder firstArgToApproxPercentileNode = ImmutableListMultimap.builder(); + private final ImmutableListMultimap.Builder firstArgToPercentiles = ImmutableListMultimap.builder(); + + public ApproxPercentilePatternMatcher(AstNode root) + { + this.root = requireNonNull(root, "AST passed to pattern matcher was null"); + } + + public ApproxPercentilePatternMatcherResult matchPattern() + { + root.jjtAccept(this, null); + return new ApproxPercentilePatternMatcherResult(firstArgToApproxPercentileNode.build(), firstArgToPercentiles.build()); + } + + @Override + public void visit(FunctionCall node, Void data) + { + if (isApproxPercentileNode(node) && hasUnsignedLiteralSecondArg(node)) { + String firstArg = unparse(getNthArgument(node, 0).get()).trim(); + String secondArg = unparse(getNthArgument(node, 1).get()).trim(); + firstArgToApproxPercentileNode.put(firstArg, node); + firstArgToPercentiles.put(firstArg, secondArg); + } + defaultVisit(node, data); + } + } + + private static class ApproxPercentilePatternMatcherResult + { + private final ListMultimap firstArgToApproxPercentileNode; + private final ListMultimap firstArgToPercentiles; + + public ApproxPercentilePatternMatcherResult(ListMultimap firstArgToApproxPercentileNode, + ListMultimap firstArgToPercentiles) + { + this.firstArgToApproxPercentileNode = requireNonNull(firstArgToApproxPercentileNode); + this.firstArgToPercentiles = requireNonNull(firstArgToPercentiles); + } + + public ListMultimap getFirstArgToApproxPercentileNode() + { + return firstArgToApproxPercentileNode; + } + + public ListMultimap getFirstArgToPercentiles() + { + return firstArgToPercentiles; + } + } +} diff --git a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java new file mode 100644 index 0000000..9d5016a --- /dev/null +++ b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.coresql.parser.ParserHelper.parseStatement; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestApproxPercentileRewriter +{ + private static final String[] STATEMENTS_THAT_DONT_NEED_REWRITE = new String[] { + // True Negative + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y LIMIT 10) ORDER BY z LIMIT 10;", + "SELECT dealer_id, sales OVER (PARTITION BY dealer_id ORDER BY sales);", + "INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date LIMIT 100);", + "SELECT (true or false) and false;", + "SELECT * FROM T ORDER BY y;", + "SELECT * FROM T ORDER BY y LIMIT 10;", + "use a.b;", + " SELECT 1;", + "SELECT a FROM T;", + "SELECT a FROM T WHERE p1 > p2;", + "SELECT a, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT CASE a WHEN IN ( 1 ) THEN b ELSE c END AS x, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT T.* FROM T JOIN W ON T.x = W.x;", + "SELECT NULL;", + "SELECT ARRAY[x] FROM T;", + "SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "CREATE TABLE T AS SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "INSERT INTO T SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "SELECT ROW_NUMBER() OVER(PARTITION BY x) FROM T;", + "SELECT x, SUM(y) OVER (PARTITION BY y ORDER BY 1) AS min\n" + + "FROM (values ('b',10), ('a', 10)) AS T(x, y)\n;", + "SELECT\n" + + " CAST(MAP() AS map>) AS \"bool_tensor_features\";", + "SELECT f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f())))))))))))))))))))))))))))));", + "SELECT abs, 2 as abs;", + // False Positive + "SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(y, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4) FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]) FROM (SELECT APPROX_PERCENTILE(x, 0.1) from T);" + }; + + private static final ImmutableMap STATEMENT_TO_REWRITTEN_STATEMENT = + new ImmutableMap.Builder() + .put("SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] AS percentile_30 FROM T;") + .put("SELECT APPROX_PERCENTILE(x, 0.1), APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1], APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] FROM T;") + .put("SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2;", + "SELECT APPROX_PERCENTILE(y, ARRAY[0.2, 0.1])[1], x + 1, APPROX_PERCENTILE(y, ARRAY[0.2, 0.1])[2] from T group by 2;") + .put("SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[2] AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;") + .put("SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4), APPROX_PERCENTILE(x, 0.5) FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, ARRAY[0.4, 0.5])[1], APPROX_PERCENTILE(x, ARRAY[0.4, 0.5])[2] FROM T;") + .put("SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT APPROX_PERCENTILE(x, 0.2) FROM T);", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[1] FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2])[2] FROM T);") + .put("SELECT x FROM (SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T);", + "SELECT x FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[1] AS percentile_10, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[2] AS percentile_20, APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3])[3] AS percentile_30 FROM T);") + .put("CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2));", + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT APPROX_PERCENTILE(y, ARRAY[0.2, 0.1])[1], x + 1, APPROX_PERCENTILE(y, ARRAY[0.2, 0.1])[2] from T group by 2));") + .put("SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT * FROM (SELECT approx_percentile(x, 0.1), approx_percentile(y, 0.1) from T group by 2));", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.1])[1] FROM (SELECT * FROM (SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.1])[1], approx_percentile(y, 0.1) from T group by 2));") + .build(); + + private void assertStatementUnchanged(String originalStatement) + { + assertFalse(getRewriteResult(originalStatement).isPresent()); + } + + private void assertStatementRewritten(String originalStatement, String expectedStatement) + { + Optional result = getRewriteResult(originalStatement); + assertTrue(result.isPresent()); + assertEquals(result.get().getRewrittenSql(), expectedStatement); + } + + private Optional getRewriteResult(String originalStatement) + { + AstNode ast = parseStatement(originalStatement); + assertNotNull(ast); + return new ApproxPercentileRewriter(ast).rewrite(); + } + + @Test + public void rewriteTest() + { + for (Map.Entry entry : STATEMENT_TO_REWRITTEN_STATEMENT.entrySet()) { + assertStatementRewritten(entry.getKey(), entry.getValue()); + } + + for (String sql : STATEMENTS_THAT_DONT_NEED_REWRITE) { + assertStatementUnchanged(sql); + } + } +} diff --git a/rewriter/~$pom.xml b/rewriter/~$pom.xml new file mode 100644 index 0000000..0499b88 Binary files /dev/null and b/rewriter/~$pom.xml differ